BenjiELCA commited on
Commit
64b088f
·
1 Parent(s): 210f0e9

add commentary to all the code

Browse files
Files changed (9) hide show
  1. app.py +26 -11
  2. modules/OCR.py +143 -41
  3. modules/dataset_loader.py +166 -60
  4. modules/eval.py +391 -94
  5. modules/streamlit_utils.py +209 -48
  6. modules/toWizard.py +95 -14
  7. modules/toXML.py +330 -64
  8. modules/train.py +265 -240
  9. modules/utils.py +79 -196
app.py CHANGED
@@ -6,75 +6,90 @@ import numpy as np
6
  from modules.streamlit_utils import *
7
  from modules.utils import error
8
 
9
-
10
  def main():
11
- # Example usage
 
 
 
 
12
  if 'model_loaded' not in st.session_state:
13
  st.session_state.model_loaded = False
14
 
15
  st.session_state.first_run = True
 
 
16
  is_mobile, screen_width = configure_page()
 
 
17
  display_banner(is_mobile)
18
  display_title(is_mobile)
19
  display_sidebar()
20
 
 
21
  initialize_session_state()
22
 
23
  cropped_image = None
24
 
 
25
  img_selected = load_example_image()
26
  uploaded_file = load_user_image(img_selected, is_mobile)
27
 
 
28
  if uploaded_file is not None:
29
  cropped_image = display_image(uploaded_file, screen_width, is_mobile)
30
 
 
31
  if uploaded_file is not None:
32
  get_score_threshold(is_mobile)
33
 
 
34
  if st.button("🚀 Launch Prediction"):
35
  st.session_state.image = launch_prediction(cropped_image, st.session_state.score_threshold, is_mobile, screen_width)
36
  st.session_state.original_prediction = st.session_state.prediction.copy()
37
  st.rerun()
38
 
39
- # Create placeholders for all sections
40
  prediction_result_placeholder = st.empty()
41
  additional_options_placeholder = st.empty()
42
  modeler_placeholder = st.empty()
43
 
44
-
45
  if 'prediction' in st.session_state and uploaded_file:
46
  if st.session_state.image != cropped_image:
47
  print('Image has changed')
48
- # Delete the prediction
49
  del st.session_state.prediction
50
  return
51
 
52
- if len(st.session_state.prediction['labels'])==0:
53
- error("No prediction available. Please upload a BPMN image or decrease the detection score treshold.")
54
  else:
55
  with prediction_result_placeholder.container():
56
  if is_mobile:
57
- display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6*screen_width))
58
  else:
59
  with st.expander("Show result of prediction"):
60
- display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6*screen_width))
61
 
 
62
  if not is_mobile:
63
  with additional_options_placeholder.container():
64
  state = modify_results()
65
 
66
-
67
  with modeler_placeholder.container():
68
  modeler_options(is_mobile)
69
  display_bpmn_modeler(is_mobile, screen_width)
70
  else:
 
71
  prediction_result_placeholder.empty()
72
  additional_options_placeholder.empty()
73
  modeler_placeholder.empty()
74
- # Create a lot of space for scrolling
75
  for _ in range(50):
76
  st.text("")
77
 
 
78
  gc.collect()
79
 
80
  if __name__ == "__main__":
 
6
  from modules.streamlit_utils import *
7
  from modules.utils import error
8
 
 
9
  def main():
10
+ """
11
+ Main function to run the Streamlit application for BPMN AI model recognition.
12
+ """
13
+
14
+ # Check if the model is loaded in the session state
15
  if 'model_loaded' not in st.session_state:
16
  st.session_state.model_loaded = False
17
 
18
  st.session_state.first_run = True
19
+
20
+ # Configure the Streamlit page and retrieve screen details
21
  is_mobile, screen_width = configure_page()
22
+
23
+ # Display various UI components
24
  display_banner(is_mobile)
25
  display_title(is_mobile)
26
  display_sidebar()
27
 
28
+ # Initialize session state variables
29
  initialize_session_state()
30
 
31
  cropped_image = None
32
 
33
+ # Load example or user-uploaded image
34
  img_selected = load_example_image()
35
  uploaded_file = load_user_image(img_selected, is_mobile)
36
 
37
+ # Display the uploaded image and allow cropping
38
  if uploaded_file is not None:
39
  cropped_image = display_image(uploaded_file, screen_width, is_mobile)
40
 
41
+ # Set score threshold for prediction if an image is uploaded
42
  if uploaded_file is not None:
43
  get_score_threshold(is_mobile)
44
 
45
+ # Launch prediction when the button is clicked
46
  if st.button("🚀 Launch Prediction"):
47
  st.session_state.image = launch_prediction(cropped_image, st.session_state.score_threshold, is_mobile, screen_width)
48
  st.session_state.original_prediction = st.session_state.prediction.copy()
49
  st.rerun()
50
 
51
+ # Create placeholders for different sections of the UI
52
  prediction_result_placeholder = st.empty()
53
  additional_options_placeholder = st.empty()
54
  modeler_placeholder = st.empty()
55
 
56
+ # Display prediction results and options if predictions are available
57
  if 'prediction' in st.session_state and uploaded_file:
58
  if st.session_state.image != cropped_image:
59
  print('Image has changed')
60
+ # Delete the prediction if the image has changed
61
  del st.session_state.prediction
62
  return
63
 
64
+ if len(st.session_state.prediction['labels']) == 0:
65
+ error("No prediction available. Please upload a BPMN image or decrease the detection score threshold.")
66
  else:
67
  with prediction_result_placeholder.container():
68
  if is_mobile:
69
+ display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6 * screen_width))
70
  else:
71
  with st.expander("Show result of prediction"):
72
+ display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6 * screen_width))
73
 
74
+ # Provide additional options for modification if not on mobile
75
  if not is_mobile:
76
  with additional_options_placeholder.container():
77
  state = modify_results()
78
 
79
+ # Display BPMN modeler options and result
80
  with modeler_placeholder.container():
81
  modeler_options(is_mobile)
82
  display_bpmn_modeler(is_mobile, screen_width)
83
  else:
84
+ # Clear placeholders if no predictions are available
85
  prediction_result_placeholder.empty()
86
  additional_options_placeholder.empty()
87
  modeler_placeholder.empty()
88
+ # Create space for scrolling
89
  for _ in range(50):
90
  st.text("")
91
 
92
+ # Force garbage collection
93
  gc.collect()
94
 
95
  if __name__ == "__main__":
modules/OCR.py CHANGED
@@ -3,13 +3,14 @@ import os
3
  from azure.ai.vision.imageanalysis import ImageAnalysisClient
4
  from azure.ai.vision.imageanalysis.models import VisualFeatures
5
  from azure.core.credentials import AzureKeyCredential
6
- import time
7
  import numpy as np
8
  import networkx as nx
9
  from modules.utils import class_dict, proportion_inside
10
  import json
11
  from modules.utils import rescale_boxes as rescale, is_vertical
12
- import streamlit as st
 
 
13
 
14
  VISION_KEY = os.getenv("VISION_KEY")
15
  VISION_ENDPOINT = os.getenv("VISION_ENDPOINT")
@@ -20,15 +21,17 @@ VISION_ENDPOINT = os.getenv("VISION_ENDPOINT")
20
 
21
  VISION_KEY = json_data["VISION_KEY"]
22
  VISION_ENDPOINT = json_data["VISION_ENDPOINT"]"""
23
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
24
- import torch
25
- import logging
26
 
27
  # Suppress specific warnings from transformers
28
  logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
29
 
30
  # Function to initialize the model and tokenizer
31
  def initialize_model():
 
 
 
32
  tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
33
  model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
34
  return tokenizer, model
@@ -38,6 +41,17 @@ tokenizer, emotion_model = initialize_model()
38
 
39
  # Function to perform sentiment analysis and return the highest scoring emotion and its score between positive and negative
40
  def analyze_sentiment(sentence, tokenizer=tokenizer, model=emotion_model):
 
 
 
 
 
 
 
 
 
 
 
41
  inputs = tokenizer(sentence, return_tensors="pt")
42
  outputs = model(**inputs)
43
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze().tolist()
@@ -51,8 +65,16 @@ def analyze_sentiment(sentence, tokenizer=tokenizer, model=emotion_model):
51
  return highest_emotion, highest_score
52
 
53
  def sample_ocr_image_file(image_data):
54
- # Set the values of your computer vision endpoint and computer vision key
55
- # as environment variables:
 
 
 
 
 
 
 
 
56
  try:
57
  endpoint = VISION_ENDPOINT
58
  key = VISION_KEY
@@ -77,16 +99,35 @@ def sample_ocr_image_file(image_data):
77
 
78
 
79
  def text_prediction(image):
80
- #transform the image into a byte array
 
 
 
 
 
 
 
 
 
81
  image.save('temp.jpg')
82
  with open('temp.jpg', 'rb') as f:
83
  image_data = f.read()
84
  ocr_result = sample_ocr_image_file(image_data)
85
- #delete the temporary image
86
  os.remove('temp.jpg')
87
  return ocr_result
88
 
89
  def filter_text(ocr_result, threshold=0.5):
 
 
 
 
 
 
 
 
 
 
90
  words_to_cancel = {"-","--","---","+",".",",","#","@","!","?","(",")","[","]","{","}","<",">","/","\\","|","-","_","=","&","^","%","$","£","€","¥","¢","¤","§","©","®","™","°","±","×","÷","¶","∆","∏","∑","∞","√","∫","≈","≠","≤","≥","≡","∼"}
91
  # Add every other one-letter word to the list of words to cancel, except 'I' and 'a'
92
  for letter in "bcdefghjklmnopqrstuvwxyz1234567890": # All lowercase letters except 'a'
@@ -132,10 +173,16 @@ def filter_text(ocr_result, threshold=0.5):
132
  return list_of_lines
133
 
134
 
 
 
 
135
 
 
 
136
 
137
- def get_box_points(box):
138
- """Returns all critical points of a box: corners and midpoints of edges."""
 
139
  xmin, ymin, xmax, ymax = box
140
  return np.array([
141
  [xmin, ymin], # Bottom-left corner
@@ -149,7 +196,16 @@ def get_box_points(box):
149
  ])
150
 
151
  def min_distance_between_boxes(box1, box2):
152
- """Computes the minimum distance between two boxes considering all critical points."""
 
 
 
 
 
 
 
 
 
153
  points1 = get_box_points(box1)
154
  points2 = get_box_points(box2)
155
 
@@ -162,7 +218,17 @@ def min_distance_between_boxes(box1, box2):
162
  return min_dist
163
 
164
  def are_close(box1, box2, threshold=50):
165
- """Determines if boxes are close based on their corners and center points."""
 
 
 
 
 
 
 
 
 
 
166
  corners1 = np.array([
167
  [box1[0], box1[1]], [box1[0], box1[3]], [box1[2], box1[1]], [box1[2], box1[3]],
168
  [(box1[0]+box1[2])/2, box1[1]], [(box1[0]+box1[2])/2, box1[3]],
@@ -180,13 +246,25 @@ def are_close(box1, box2, threshold=50):
180
  return False
181
 
182
  def find_closest_box(text_box, all_boxes, labels, threshold, iou_threshold=0.5):
183
- """Find the closest box to the given text box within a specified threshold."""
 
 
 
 
 
 
 
 
 
 
 
 
184
  min_distance = float('inf')
185
  closest_index = None
186
 
187
- #check if the text is inside a sequenceFlow
188
  for j in range(len(all_boxes)):
189
- if proportion_inside(text_box, all_boxes[j])>iou_threshold and labels[j] == list(class_dict.values()).index('sequenceFlow'):
190
  return j
191
 
192
  for i, box in enumerate(all_boxes):
@@ -209,20 +287,32 @@ def find_closest_box(text_box, all_boxes, labels, threshold, iou_threshold=0.5):
209
  return None
210
 
211
 
212
-
213
  def group_texts(task_boxes, text_boxes, texts, min_dist=50, iou_threshold=0.8, percentage_thresh=0.8):
214
- """Maps text boxes to task boxes and groups texts within each task based on proximity."""
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  G = nx.Graph()
216
 
217
  # Map each text box to the nearest task box
218
  task_to_texts = {i: [] for i in range(len(task_boxes))}
219
- information_texts = [] # texts not inside any task box
220
  text_to_task_mapped = [False] * len(text_boxes)
221
 
222
  for idx, text_box in enumerate(text_boxes):
223
  mapped = False
224
  for jdx, task_box in enumerate(task_boxes):
225
- if proportion_inside(text_box, task_box)>iou_threshold:
226
  task_to_texts[jdx].append(idx)
227
  text_to_task_mapped[idx] = True
228
  mapped = True
@@ -326,32 +416,45 @@ def group_texts(task_boxes, text_boxes, texts, min_dist=50, iou_threshold=0.8, p
326
  return all_grouped_texts, sentence_boxes, information_grouped_texts, info_sentence_boxes
327
 
328
 
329
- def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0.6,scale=1.0, iou_threshold=0.5):
 
 
330
 
 
 
 
 
 
 
 
 
 
 
 
331
  boxes = rescale(scale, full_pred['boxes'])
332
 
333
  min_dist = 200
334
  labels = full_pred['labels']
335
  avoid = [list(class_dict.values()).index('pool'), list(class_dict.values()).index('lane'), list(class_dict.values()).index('sequenceFlow'), list(class_dict.values()).index('messageFlow'), list(class_dict.values()).index('dataAssociation')]
336
  for i in range(len(boxes)):
337
- box1 = boxes[i]
338
- if labels[i] in avoid:
 
 
 
 
339
  continue
340
- for j in range(i + 1, len(boxes)):
341
- box2 = boxes[j]
342
- if labels[j] in avoid:
343
- continue
344
- dist = min_distance_between_boxes(box1, box2)
345
- min_dist = min(min_dist, dist)
346
 
347
- #print("Minimum distance between boxes:", min_dist)
 
348
 
349
  text_pred[0] = rescale(scale, text_pred[0])
350
  task_boxes = [box for i, box in enumerate(boxes) if full_pred['labels'][i] == list(class_dict.values()).index('task')]
351
  grouped_sentences, sentence_bounding_boxes, info_texts, info_boxes = group_texts(task_boxes, text_pred[0], text_pred[1], min_dist=min_dist)
352
  BPMN_id = set(full_pred['BPMN_id']) # This ensures uniqueness of task names
353
  text_mapping = {id: '' for id in BPMN_id}
354
-
355
 
356
  if print_sentences:
357
  for sentence, box in zip(grouped_sentences, sentence_bounding_boxes):
@@ -363,8 +466,8 @@ def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0
363
  # Map the grouped sentences to the corresponding task
364
  for i in range(len(sentence_bounding_boxes)):
365
  for j in range(len(boxes)):
366
- if proportion_inside(sentence_bounding_boxes[i], boxes[j])>iou_threshold and full_pred['labels'][j] == list(class_dict.values()).index('task'):
367
- text_mapping[full_pred['BPMN_id'][j]]=grouped_sentences[i]
368
 
369
  # Map the grouped sentences to the corresponding pool
370
  for key, elements in full_pred['pool_dict'].items():
@@ -372,17 +475,16 @@ def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0
372
  continue
373
  else:
374
  for i in range(len(info_boxes)):
375
- #find the position of the key in BPMN_id
376
  position = list(full_pred['BPMN_id']).index(key)
377
- if proportion_inside(info_boxes[i], boxes[position])>iou_threshold:
378
  text_mapping[key] = info_texts[i]
379
  info_texts[i] = '' # Clear the text to avoid re-use
380
 
381
-
382
  for i in range(len(info_boxes)):
383
  if is_vertical(info_boxes[i]):
384
  for j in range(len(boxes)):
385
- if proportion_inside(info_boxes[i], boxes[j])>0 and full_pred['labels'][j] == list(class_dict.values()).index('pool'):
386
  print("Text:", info_texts[i], "associate with ", full_pred['BPMN_id'][j])
387
  bpmn_id = full_pred['BPMN_id'][j]
388
  # Append new text or create new entry if not existing
@@ -399,10 +501,10 @@ def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0
399
  for j in range(len(boxes)):
400
  if info_texts[i] == '':
401
  continue # Skip if there's no text
402
- if (proportion_inside(info_boxes[i], boxes[j])>0 or are_close(info_boxes[i], boxes[j], threshold=percentage_thresh*min_dist)) and (full_pred['labels'][j] == list(class_dict.values()).index('event')
403
  or full_pred['labels'][j] == list(class_dict.values()).index('messageEvent')
404
  or full_pred['labels'][j] == list(class_dict.values()).index('timerEvent')
405
- or full_pred['labels'][j] == list(class_dict.values()).index('dataObject')) :
406
  bpmn_id = full_pred['BPMN_id'][j]
407
  # Append new text or create new entry if not existing
408
  if bpmn_id in text_mapping:
@@ -416,7 +518,7 @@ def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0
416
  if info_texts[i] == '' or is_vertical(info_boxes[i]):
417
  continue # Skip if there's no text
418
  # Find the closest box within the defined threshold
419
- closest_index = find_closest_box(info_boxes[i], boxes, full_pred['labels'], threshold=4*min_dist)
420
  if closest_index is not None and (full_pred['labels'][closest_index] == list(class_dict.values()).index('sequenceFlow') or full_pred['labels'][closest_index] == list(class_dict.values()).index('messageFlow')):
421
  bpmn_id = full_pred['BPMN_id'][closest_index]
422
  # Append new text or create new entry if not existing
@@ -430,4 +532,4 @@ def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0
430
  print("Text Mapping:", text_mapping)
431
  print("Information Texts left:", info_texts)
432
 
433
- return text_mapping
 
3
  from azure.ai.vision.imageanalysis import ImageAnalysisClient
4
  from azure.ai.vision.imageanalysis.models import VisualFeatures
5
  from azure.core.credentials import AzureKeyCredential
 
6
  import numpy as np
7
  import networkx as nx
8
  from modules.utils import class_dict, proportion_inside
9
  import json
10
  from modules.utils import rescale_boxes as rescale, is_vertical
11
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
+ import torch
13
+ import logging
14
 
15
  VISION_KEY = os.getenv("VISION_KEY")
16
  VISION_ENDPOINT = os.getenv("VISION_ENDPOINT")
 
21
 
22
  VISION_KEY = json_data["VISION_KEY"]
23
  VISION_ENDPOINT = json_data["VISION_ENDPOINT"]"""
24
+
25
+
 
26
 
27
  # Suppress specific warnings from transformers
28
  logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
29
 
30
  # Function to initialize the model and tokenizer
31
  def initialize_model():
32
+ """
33
+ Initialize the tokenizer and model for sentiment analysis.
34
+ """
35
  tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
36
  model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
37
  return tokenizer, model
 
41
 
42
  # Function to perform sentiment analysis and return the highest scoring emotion and its score between positive and negative
43
  def analyze_sentiment(sentence, tokenizer=tokenizer, model=emotion_model):
44
+ """
45
+ Analyze the sentiment of a given sentence using the initialized tokenizer and model.
46
+
47
+ Parameters:
48
+ - sentence (str): The input sentence to analyze.
49
+ - tokenizer (AutoTokenizer): The tokenizer for processing the sentence.
50
+ - model (AutoModelForSequenceClassification): The model for sentiment analysis.
51
+
52
+ Returns:
53
+ - tuple: The highest scoring emotion ('positive' or 'negative') and its corresponding score.
54
+ """
55
  inputs = tokenizer(sentence, return_tensors="pt")
56
  outputs = model(**inputs)
57
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze().tolist()
 
65
  return highest_emotion, highest_score
66
 
67
  def sample_ocr_image_file(image_data):
68
+ """
69
+ Sample OCR function to analyze an image file and extract text using Azure's Computer Vision service.
70
+
71
+ Parameters:
72
+ - image_data (bytes): The image data in bytes.
73
+
74
+ Returns:
75
+ - result: The OCR result from the Computer Vision service.
76
+ """
77
+ # Set the values of your computer vision endpoint and computer vision key as environment variables:
78
  try:
79
  endpoint = VISION_ENDPOINT
80
  key = VISION_KEY
 
99
 
100
 
101
  def text_prediction(image):
102
+ """
103
+ Perform OCR on an image to extract text.
104
+
105
+ Parameters:
106
+ - image: The image to process.
107
+
108
+ Returns:
109
+ - ocr_result: The OCR result.
110
+ """
111
+ # Transform the image into a byte array
112
  image.save('temp.jpg')
113
  with open('temp.jpg', 'rb') as f:
114
  image_data = f.read()
115
  ocr_result = sample_ocr_image_file(image_data)
116
+ # Delete the temporary image
117
  os.remove('temp.jpg')
118
  return ocr_result
119
 
120
  def filter_text(ocr_result, threshold=0.5):
121
+ """
122
+ Filter and process the OCR results to remove unwanted characters and low-confidence words.
123
+
124
+ Parameters:
125
+ - ocr_result: The OCR result.
126
+ - threshold (float): The confidence threshold for filtering words.
127
+
128
+ Returns:
129
+ - list_of_lines: Processed text lines and their bounding boxes.
130
+ """
131
  words_to_cancel = {"-","--","---","+",".",",","#","@","!","?","(",")","[","]","{","}","<",">","/","\\","|","-","_","=","&","^","%","$","£","€","¥","¢","¤","§","©","®","™","°","±","×","÷","¶","∆","∏","∑","∞","√","∫","≈","≠","≤","≥","≡","∼"}
132
  # Add every other one-letter word to the list of words to cancel, except 'I' and 'a'
133
  for letter in "bcdefghjklmnopqrstuvwxyz1234567890": # All lowercase letters except 'a'
 
173
  return list_of_lines
174
 
175
 
176
+ def get_box_points(box):
177
+ """
178
+ Returns all critical points of a box: corners and midpoints of edges.
179
 
180
+ Parameters:
181
+ - box (array): Bounding box coordinates [xmin, ymin, xmax, ymax].
182
 
183
+ Returns:
184
+ - numpy.array: Array of critical points.
185
+ """
186
  xmin, ymin, xmax, ymax = box
187
  return np.array([
188
  [xmin, ymin], # Bottom-left corner
 
196
  ])
197
 
198
  def min_distance_between_boxes(box1, box2):
199
+ """
200
+ Computes the minimum distance between two boxes considering all critical points.
201
+
202
+ Parameters:
203
+ - box1 (array): First bounding box coordinates.
204
+ - box2 (array): Second bounding box coordinates.
205
+
206
+ Returns:
207
+ - float: The minimum distance between the two boxes.
208
+ """
209
  points1 = get_box_points(box1)
210
  points2 = get_box_points(box2)
211
 
 
218
  return min_dist
219
 
220
  def are_close(box1, box2, threshold=50):
221
+ """
222
+ Determines if boxes are close based on their corners and center points.
223
+
224
+ Parameters:
225
+ - box1 (array): First bounding box coordinates.
226
+ - box2 (array): Second bounding box coordinates.
227
+ - threshold (int): Distance threshold for determining closeness.
228
+
229
+ Returns:
230
+ - bool: True if boxes are close, otherwise False.
231
+ """
232
  corners1 = np.array([
233
  [box1[0], box1[1]], [box1[0], box1[3]], [box1[2], box1[1]], [box1[2], box1[3]],
234
  [(box1[0]+box1[2])/2, box1[1]], [(box1[0]+box1[2])/2, box1[3]],
 
246
  return False
247
 
248
  def find_closest_box(text_box, all_boxes, labels, threshold, iou_threshold=0.5):
249
+ """
250
+ Find the closest box to the given text box within a specified threshold.
251
+
252
+ Parameters:
253
+ - text_box (array): The text box coordinates.
254
+ - all_boxes (list): List of all bounding boxes.
255
+ - labels (list): List of labels corresponding to the boxes.
256
+ - threshold (float): Distance threshold for determining closeness.
257
+ - iou_threshold (float): IoU threshold for determining if a text is inside a sequenceFlow.
258
+
259
+ Returns:
260
+ - int or None: Index of the closest box or None if no box is close enough.
261
+ """
262
  min_distance = float('inf')
263
  closest_index = None
264
 
265
+ # Check if the text is inside a sequenceFlow
266
  for j in range(len(all_boxes)):
267
+ if proportion_inside(text_box, all_boxes[j]) > iou_threshold and labels[j] == list(class_dict.values()).index('sequenceFlow'):
268
  return j
269
 
270
  for i, box in enumerate(all_boxes):
 
287
  return None
288
 
289
 
 
290
  def group_texts(task_boxes, text_boxes, texts, min_dist=50, iou_threshold=0.8, percentage_thresh=0.8):
291
+ """
292
+ Maps text boxes to task boxes and groups texts within each task based on proximity.
293
+
294
+ Parameters:
295
+ - task_boxes (list): List of task bounding boxes.
296
+ - text_boxes (list): List of text bounding boxes.
297
+ - texts (list): List of texts corresponding to the text boxes.
298
+ - min_dist (float): Minimum distance threshold for grouping.
299
+ - iou_threshold (float): IoU threshold for determining if text is inside a task box.
300
+ - percentage_thresh (float): Percentage threshold for determining if text boxes are close.
301
+
302
+ Returns:
303
+ - tuple: Grouped task-related texts, their bounding boxes, grouped information texts, and their bounding boxes.
304
+ """
305
  G = nx.Graph()
306
 
307
  # Map each text box to the nearest task box
308
  task_to_texts = {i: [] for i in range(len(task_boxes))}
309
+ information_texts = [] # Texts not inside any task box
310
  text_to_task_mapped = [False] * len(text_boxes)
311
 
312
  for idx, text_box in enumerate(text_boxes):
313
  mapped = False
314
  for jdx, task_box in enumerate(task_boxes):
315
+ if proportion_inside(text_box, task_box) > iou_threshold:
316
  task_to_texts[jdx].append(idx)
317
  text_to_task_mapped[idx] = True
318
  mapped = True
 
416
  return all_grouped_texts, sentence_boxes, information_grouped_texts, info_sentence_boxes
417
 
418
 
419
+ def mapping_text(full_pred, text_pred, print_sentences=False, percentage_thresh=0.6, scale=1.0, iou_threshold=0.5):
420
+ """
421
+ Map the extracted texts to the predicted bounding boxes.
422
 
423
+ Parameters:
424
+ - full_pred (dict): Full prediction dictionary containing boxes, labels, BPMN IDs, and pool dictionary.
425
+ - text_pred (list): List containing text predictions and their bounding boxes.
426
+ - print_sentences (bool): Whether to print the sentences and their bounding boxes.
427
+ - percentage_thresh (float): Percentage threshold for determining closeness.
428
+ - scale (float): Scale factor for rescaling bounding boxes.
429
+ - iou_threshold (float): IoU threshold for determining if text is inside a bounding box.
430
+
431
+ Returns:
432
+ - dict: Text mapping for BPMN elements.
433
+ """
434
  boxes = rescale(scale, full_pred['boxes'])
435
 
436
  min_dist = 200
437
  labels = full_pred['labels']
438
  avoid = [list(class_dict.values()).index('pool'), list(class_dict.values()).index('lane'), list(class_dict.values()).index('sequenceFlow'), list(class_dict.values()).index('messageFlow'), list(class_dict.values()).index('dataAssociation')]
439
  for i in range(len(boxes)):
440
+ box1 = boxes[i]
441
+ if labels[i] in avoid:
442
+ continue
443
+ for j in range(i + 1, len(boxes)):
444
+ box2 = boxes[j]
445
+ if labels[j] in avoid:
446
  continue
447
+ dist = min_distance_between_boxes(box1, box2)
448
+ min_dist = min(min_dist, dist)
 
 
 
 
449
 
450
+ # Print the minimum distance between boxes
451
+ # print("Minimum distance between boxes:", min_dist)
452
 
453
  text_pred[0] = rescale(scale, text_pred[0])
454
  task_boxes = [box for i, box in enumerate(boxes) if full_pred['labels'][i] == list(class_dict.values()).index('task')]
455
  grouped_sentences, sentence_bounding_boxes, info_texts, info_boxes = group_texts(task_boxes, text_pred[0], text_pred[1], min_dist=min_dist)
456
  BPMN_id = set(full_pred['BPMN_id']) # This ensures uniqueness of task names
457
  text_mapping = {id: '' for id in BPMN_id}
 
458
 
459
  if print_sentences:
460
  for sentence, box in zip(grouped_sentences, sentence_bounding_boxes):
 
466
  # Map the grouped sentences to the corresponding task
467
  for i in range(len(sentence_bounding_boxes)):
468
  for j in range(len(boxes)):
469
+ if proportion_inside(sentence_bounding_boxes[i], boxes[j]) > iou_threshold and full_pred['labels'][j] == list(class_dict.values()).index('task'):
470
+ text_mapping[full_pred['BPMN_id'][j]] = grouped_sentences[i]
471
 
472
  # Map the grouped sentences to the corresponding pool
473
  for key, elements in full_pred['pool_dict'].items():
 
475
  continue
476
  else:
477
  for i in range(len(info_boxes)):
478
+ # Find the position of the key in BPMN_id
479
  position = list(full_pred['BPMN_id']).index(key)
480
+ if proportion_inside(info_boxes[i], boxes[position]) > iou_threshold:
481
  text_mapping[key] = info_texts[i]
482
  info_texts[i] = '' # Clear the text to avoid re-use
483
 
 
484
  for i in range(len(info_boxes)):
485
  if is_vertical(info_boxes[i]):
486
  for j in range(len(boxes)):
487
+ if proportion_inside(info_boxes[i], boxes[j]) > 0 and full_pred['labels'][j] == list(class_dict.values()).index('pool'):
488
  print("Text:", info_texts[i], "associate with ", full_pred['BPMN_id'][j])
489
  bpmn_id = full_pred['BPMN_id'][j]
490
  # Append new text or create new entry if not existing
 
501
  for j in range(len(boxes)):
502
  if info_texts[i] == '':
503
  continue # Skip if there's no text
504
+ if (proportion_inside(info_boxes[i], boxes[j]) > 0 or are_close(info_boxes[i], boxes[j], threshold=percentage_thresh * min_dist)) and (full_pred['labels'][j] == list(class_dict.values()).index('event')
505
  or full_pred['labels'][j] == list(class_dict.values()).index('messageEvent')
506
  or full_pred['labels'][j] == list(class_dict.values()).index('timerEvent')
507
+ or full_pred['labels'][j] == list(class_dict.values()).index('dataObject')):
508
  bpmn_id = full_pred['BPMN_id'][j]
509
  # Append new text or create new entry if not existing
510
  if bpmn_id in text_mapping:
 
518
  if info_texts[i] == '' or is_vertical(info_boxes[i]):
519
  continue # Skip if there's no text
520
  # Find the closest box within the defined threshold
521
+ closest_index = find_closest_box(info_boxes[i], boxes, full_pred['labels'], threshold=4 * min_dist)
522
  if closest_index is not None and (full_pred['labels'][closest_index] == list(class_dict.values()).index('sequenceFlow') or full_pred['labels'][closest_index] == list(class_dict.values()).index('messageFlow')):
523
  bpmn_id = full_pred['BPMN_id'][closest_index]
524
  # Append new text or create new entry if not existing
 
532
  print("Text Mapping:", text_mapping)
533
  print("Information Texts left:", info_texts)
534
 
535
+ return text_mapping
modules/dataset_loader.py CHANGED
@@ -1,7 +1,3 @@
1
- from torchvision.models.detection import keypointrcnn_resnet50_fpn
2
- from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
3
- from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
4
- from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
5
  import random
6
  import torch
7
  from torch.utils.data import Dataset
@@ -9,43 +5,60 @@ import torchvision.transforms.functional as F
9
  import numpy as np
10
  from torch.utils.data.dataloader import default_collate
11
  import cv2
12
- import matplotlib.pyplot as plt
13
- from torch.utils.data import DataLoader, Subset, ConcatDataset
14
- import streamlit as st
15
  from modules.utils import object_dict, arrow_dict, resize_boxes, resize_keypoints
 
 
16
 
17
  class RandomCrop:
18
- def __init__(self, new_size=(1333,800),crop_fraction=0.5, min_objects=4):
 
 
 
 
 
 
 
 
19
  self.crop_fraction = crop_fraction
20
  self.min_objects = min_objects
21
  self.new_size = new_size
22
 
23
  def __call__(self, image, target):
 
 
 
 
 
 
 
 
 
 
24
  new_w1, new_h1 = self.new_size
25
  w, h = image.size
26
  new_w = int(w * self.crop_fraction)
27
- new_h = int(new_w*new_h1/new_w1)
28
-
29
- i=0
30
- for i in range(4):
31
- if new_h >= h:
32
- i += 0.05
33
- new_w = int(w * (self.crop_fraction - i))
34
- new_h = int(new_w*new_h1/new_w1)
35
- if new_h < h:
36
- continue
37
 
38
- if new_h >= h:
39
- return image, target
40
 
41
  boxes = target["boxes"]
42
  if 'keypoints' in target:
43
  keypoints = target["keypoints"]
44
  else:
45
  keypoints = []
46
- for i in range(len(boxes)):
47
- keypoints.append(torch.zeros((2,3)))
48
-
49
 
50
  # Attempt to find a suitable crop region
51
  success = False
@@ -82,7 +95,7 @@ class RandomCrop:
82
  class RandomFlip:
83
  def __init__(self, h_flip_prob=0.5, v_flip_prob=0.5):
84
  """
85
- Initializes the RandomFlip with probabilities for flipping.
86
 
87
  Parameters:
88
  - h_flip_prob (float): Probability of applying a horizontal flip to the image.
@@ -93,7 +106,7 @@ class RandomFlip:
93
 
94
  def __call__(self, image, target):
95
  """
96
- Applies random horizontal and/or vertical flip to the image and updates target data accordingly.
97
 
98
  Parameters:
99
  - image (PIL Image): The image to be flipped.
@@ -143,12 +156,12 @@ class RandomFlip:
143
  target['keypoints'] = torch.stack(new_keypoints)
144
 
145
  return image, target
146
-
147
 
148
  class RandomRotate:
149
  def __init__(self, max_rotate_deg=20, rotate_proba=0.3):
150
  """
151
- Initializes the RandomRotate with a maximum rotation angle and probability of rotating.
152
 
153
  Parameters:
154
  - max_rotate_deg (int): Maximum degree to rotate the image.
@@ -159,7 +172,7 @@ class RandomRotate:
159
 
160
  def __call__(self, image, target):
161
  """
162
- Randomly rotates the image and updates the target data accordingly.
163
 
164
  Parameters:
165
  - image (PIL Image): The image to be rotated.
@@ -170,7 +183,7 @@ class RandomRotate:
170
  """
171
  if random.random() < self.rotate_proba:
172
  angle = random.uniform(-self.max_rotate_deg, self.max_rotate_deg)
173
- image = F.rotate(image, angle, expand=False, fill=200)
174
 
175
  # Rotate bounding boxes
176
  w, h = image.size
@@ -194,7 +207,16 @@ class RandomRotate:
194
 
195
  def rotate_box(self, box, angle, cx, cy):
196
  """
197
- Rotates a bounding box by a given angle around the center of the image.
 
 
 
 
 
 
 
 
 
198
  """
199
  x1, y1, x2, y2 = box
200
  corners = torch.tensor([
@@ -214,7 +236,16 @@ class RandomRotate:
214
 
215
  def rotate_keypoints(self, keypoints, angle, cx, cy):
216
  """
217
- Rotates keypoints by a given angle around the center of the image.
 
 
 
 
 
 
 
 
 
218
  """
219
  new_keypoints = []
220
  for kp in keypoints:
@@ -226,50 +257,89 @@ class RandomRotate:
226
  return torch.stack(new_keypoints)
227
 
228
  def rotate_90_box(box, angle, w, h):
 
 
 
 
 
 
 
 
 
 
 
 
229
  x1, y1, x2, y2 = box
230
  if angle == 90:
231
- return torch.tensor([y1,h-x2,y2,h-x1])
232
  elif angle == 270 or angle == -90:
233
- return torch.tensor([w-y2,x1,w-y1,x2])
234
  else:
235
  print("angle not supported")
236
 
237
  def rotate_90_keypoints(kp, angle, w, h):
 
 
 
 
 
 
 
 
 
 
 
 
238
  # Extract coordinates and visibility from each keypoint tensor
239
  x1, y1, v1 = kp[0][0], kp[0][1], kp[0][2]
240
  x2, y2, v2 = kp[1][0], kp[1][1], kp[1][2]
241
  # Swap x and y coordinates for each keypoint
242
  if angle == 90:
243
- new = [[y1, h-x1, v1], [y2, h-x2, v2]]
244
  elif angle == 270 or angle == -90:
245
- new = [[w-y1, x1, v1], [w-y2, x2, v2]]
246
 
247
  return torch.tensor(new, dtype=torch.float32)
248
-
249
 
250
  def rotate_vertical(image, target):
251
- # Rotate the image and target if the image is vertical
 
 
 
 
 
 
 
 
 
252
  new_boxes = []
253
- angle = random.choice([-90,90])
254
  image = F.rotate(image, angle, expand=True, fill=200)
255
  for box in target["boxes"]:
256
  new_box = rotate_90_box(box, angle, image.size[0], image.size[1])
257
  new_boxes.append(new_box)
258
  target["boxes"] = torch.stack(new_boxes)
259
-
260
  if 'keypoints' in target:
261
- new_kp = []
262
- for kp in target['keypoints']:
263
  new_key = rotate_90_keypoints(kp, angle, image.size[0], image.size[1])
264
  new_kp.append(new_key)
265
  target['keypoints'] = torch.stack(new_kp)
266
  return image, target
267
 
 
 
 
268
 
269
- import torchvision.transforms.functional as F
270
- import torch
 
 
271
 
272
- def resize_and_pad(image, target, new_size=(1333, 800)):
 
 
273
  original_size = image.size
274
  # Calculate scale to fit the new size while maintaining aspect ratio
275
  scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
@@ -302,8 +372,24 @@ def resize_and_pad(image, target, new_size=(1333, 800)):
302
  return image, target
303
 
304
  class BPMN_Dataset(Dataset):
305
- def __init__(self, annotations, transform=None, crop_transform=None, crop_prob=0.3, rotate_90_proba=0.2,
306
- flip_transform=None, rotate_transform=None, new_size=(1333,1333), keep_ratio=0.1, resize=True, model_type='object'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  self.annotations = annotations
308
  print(f"Loaded {len(self.annotations)} annotations.")
309
  self.transform = transform
@@ -322,15 +408,30 @@ class BPMN_Dataset(Dataset):
322
  self.rotate_90_proba = rotate_90_proba
323
 
324
  def __len__(self):
 
 
 
 
 
 
325
  return len(self.annotations)
326
 
327
  def __getitem__(self, idx):
 
 
 
 
 
 
 
 
 
328
  annotation = self.annotations[idx]
329
  image = annotation.img.convert("RGB")
330
  boxes = torch.tensor(np.array(annotation.boxes_ltrb), dtype=torch.float32)
331
  labels_names = [ann for ann in annotation.categories]
332
 
333
- # Only keep the labels, boxes and keypoints that are in the class_dict
334
  kept_indices = [i for i, ann in enumerate(annotation.categories) if ann in self.dict.values()]
335
  boxes = boxes[kept_indices]
336
  labels_names = [ann for i, ann in enumerate(labels_names) if i in kept_indices]
@@ -351,7 +452,7 @@ class BPMN_Dataset(Dataset):
351
  if ann.category in ["sequenceFlow", "messageFlow", "dataAssociation"]:
352
  # Fill the keypoints tensor for this annotation, mark as visible (1)
353
  kp = np.array(ann.keypoints, dtype=np.float32).reshape(-1, 3)
354
- kp = kp[:,:2]
355
  visible = np.ones((kp.shape[0], 1), dtype=np.float32)
356
  kp = np.hstack([kp, visible])
357
  keypoints[ii, :kp.shape[0], :] = torch.tensor(kp, dtype=torch.float32)
@@ -359,17 +460,17 @@ class BPMN_Dataset(Dataset):
359
 
360
  area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
361
 
362
- if self.model_type == 'object':
363
  target = {
364
  "boxes": boxes,
365
  "labels": labels_id,
366
- #"area": area,
367
  }
368
  elif self.model_type == 'arrow':
369
  target = {
370
  "boxes": boxes,
371
  "labels": labels_id,
372
- #"area": area,
373
  "keypoints": keypoints,
374
  }
375
 
@@ -384,7 +485,7 @@ class BPMN_Dataset(Dataset):
384
  # Randomly apply the custom cropping transform
385
  if self.crop_transform and random.random() < self.crop_prob:
386
  image, target = self.crop_transform(image, target)
387
-
388
  # Rotate vertical image
389
  if random.random() < self.rotate_90_proba:
390
  image, target = rotate_vertical(image, target)
@@ -394,12 +495,12 @@ class BPMN_Dataset(Dataset):
394
  # Center and pad the image while keeping the aspect ratio
395
  image, target = resize_and_pad(image, target, self.new_size)
396
  else:
397
- target['boxes'] = resize_boxes(target['boxes'], (image.size[0],image.size[1]), self.new_size)
398
  if 'area' in target:
399
  target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
400
  if 'keypoints' in target:
401
  for i in range(len(target['keypoints'])):
402
- target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0],image.size[1]), self.new_size)
403
  image = F.resize(image, (self.new_size[1], self.new_size[0]))
404
 
405
  return self.transform(image), target
@@ -429,15 +530,15 @@ def collate_fn(batch):
429
  return images, targets
430
 
431
 
432
-
433
- def create_loader(new_size,transformation, annotations1, annotations2=None,
434
  batch_size=4, crop_prob=0.2, crop_fraction=0.7, min_objects=3,
435
  h_flip_prob=0.3, v_flip_prob=0.3, max_rotate_deg=20, rotate_90_proba=0.2, rotate_proba=0.3,
436
- seed=42, resize=True, keep_ratio=0.1, model_type = 'object'):
437
  """
438
- Creates a DataLoader for BPMN datasets with optional transformations and concatenation of two datasets.
439
 
440
  Parameters:
 
441
  - transformation (callable): Transformation function to apply to each image (e.g., normalization).
442
  - annotations1 (list): Primary list of annotations.
443
  - annotations2 (list, optional): Secondary list of annotations to concatenate with the first.
@@ -447,15 +548,20 @@ def create_loader(new_size,transformation, annotations1, annotations2=None,
447
  - min_objects (int): Minimum number of objects required to be within the crop.
448
  - h_flip_prob (float): Probability of applying horizontal flip.
449
  - v_flip_prob (float): Probability of applying vertical flip.
 
 
 
450
  - seed (int): Seed for random number generators for reproducibility.
451
  - resize (bool): Flag indicating whether to resize images after transformations.
 
 
452
 
453
  Returns:
454
  - DataLoader: Configured data loader for the dataset.
455
  """
456
 
457
  # Initialize custom transformations for cropping and flipping
458
- custom_crop_transform = RandomCrop(new_size,crop_fraction, min_objects)
459
  custom_flip_transform = RandomFlip(h_flip_prob, v_flip_prob)
460
  custom_rotate_transform = RandomRotate(max_rotate_deg, rotate_proba)
461
 
@@ -497,4 +603,4 @@ def create_loader(new_size,transformation, annotations1, annotations2=None,
497
  # Create the DataLoader with the dataset
498
  data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
499
 
500
- return data_loader
 
 
 
 
 
1
  import random
2
  import torch
3
  from torch.utils.data import Dataset
 
5
  import numpy as np
6
  from torch.utils.data.dataloader import default_collate
7
  import cv2
8
+ from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
 
 
9
  from modules.utils import object_dict, arrow_dict, resize_boxes, resize_keypoints
10
+ import torchvision.transforms.functional as F
11
+ import torch
12
 
13
  class RandomCrop:
14
+ def __init__(self, new_size=(1333, 800), crop_fraction=0.5, min_objects=4):
15
+ """
16
+ Initialize the RandomCrop transformation.
17
+
18
+ Parameters:
19
+ - new_size (tuple): The target size for the image after cropping.
20
+ - crop_fraction (float): The fraction of the original width to use when cropping.
21
+ - min_objects (int): Minimum number of objects required to be within the crop.
22
+ """
23
  self.crop_fraction = crop_fraction
24
  self.min_objects = min_objects
25
  self.new_size = new_size
26
 
27
  def __call__(self, image, target):
28
+ """
29
+ Apply the RandomCrop transformation to the image and its target.
30
+
31
+ Parameters:
32
+ - image (PIL Image): The image to be cropped.
33
+ - target (dict): The target dictionary containing 'boxes' and optional 'keypoints'.
34
+
35
+ Returns:
36
+ - PIL Image, dict: The cropped image and its updated target dictionary.
37
+ """
38
  new_w1, new_h1 = self.new_size
39
  w, h = image.size
40
  new_w = int(w * self.crop_fraction)
41
+ new_h = int(new_w * new_h1 / new_w1)
42
+
43
+ i = 0
44
+ for i in range(4): # Try 4 times to adjust new_w and new_h if new_h >= h
45
+ if new_h >= h:
46
+ i += 0.05
47
+ new_w = int(w * (self.crop_fraction - i))
48
+ new_h = int(new_w * new_h1 / new_w1)
49
+ if new_h < h:
50
+ continue
51
 
52
+ if new_h >= h: # If still not valid, return original image and target
53
+ return image, target
54
 
55
  boxes = target["boxes"]
56
  if 'keypoints' in target:
57
  keypoints = target["keypoints"]
58
  else:
59
  keypoints = []
60
+ for _ in range(len(boxes)):
61
+ keypoints.append(torch.zeros((2, 3)))
 
62
 
63
  # Attempt to find a suitable crop region
64
  success = False
 
95
  class RandomFlip:
96
  def __init__(self, h_flip_prob=0.5, v_flip_prob=0.5):
97
  """
98
+ Initialize the RandomFlip transformation with probabilities for flipping.
99
 
100
  Parameters:
101
  - h_flip_prob (float): Probability of applying a horizontal flip to the image.
 
106
 
107
  def __call__(self, image, target):
108
  """
109
+ Apply random horizontal and/or vertical flip to the image and updates target data accordingly.
110
 
111
  Parameters:
112
  - image (PIL Image): The image to be flipped.
 
156
  target['keypoints'] = torch.stack(new_keypoints)
157
 
158
  return image, target
159
+
160
 
161
  class RandomRotate:
162
  def __init__(self, max_rotate_deg=20, rotate_proba=0.3):
163
  """
164
+ Initialize the RandomRotate transformation with a maximum rotation angle and probability of rotating.
165
 
166
  Parameters:
167
  - max_rotate_deg (int): Maximum degree to rotate the image.
 
172
 
173
  def __call__(self, image, target):
174
  """
175
+ Randomly rotate the image and updates the target data accordingly.
176
 
177
  Parameters:
178
  - image (PIL Image): The image to be rotated.
 
183
  """
184
  if random.random() < self.rotate_proba:
185
  angle = random.uniform(-self.max_rotate_deg, self.max_rotate_deg)
186
+ image = F.rotate(image, angle, expand=False, fill=255)
187
 
188
  # Rotate bounding boxes
189
  w, h = image.size
 
207
 
208
  def rotate_box(self, box, angle, cx, cy):
209
  """
210
+ Rotate a bounding box by a given angle around the center of the image.
211
+
212
+ Parameters:
213
+ - box (tensor): The bounding box to be rotated.
214
+ - angle (float): The angle to rotate the box.
215
+ - cx (float): The x-coordinate of the image center.
216
+ - cy (float): The y-coordinate of the image center.
217
+
218
+ Returns:
219
+ - tensor: The rotated bounding box.
220
  """
221
  x1, y1, x2, y2 = box
222
  corners = torch.tensor([
 
236
 
237
  def rotate_keypoints(self, keypoints, angle, cx, cy):
238
  """
239
+ Rotate keypoints by a given angle around the center of the image.
240
+
241
+ Parameters:
242
+ - keypoints (tensor): The keypoints to be rotated.
243
+ - angle (float): The angle to rotate the keypoints.
244
+ - cx (float): The x-coordinate of the image center.
245
+ - cy (float): The y-coordinate of the image center.
246
+
247
+ Returns:
248
+ - tensor: The rotated keypoints.
249
  """
250
  new_keypoints = []
251
  for kp in keypoints:
 
257
  return torch.stack(new_keypoints)
258
 
259
  def rotate_90_box(box, angle, w, h):
260
+ """
261
+ Rotate a bounding box by 90 degrees.
262
+
263
+ Parameters:
264
+ - box (tensor): The bounding box to be rotated.
265
+ - angle (int): The angle to rotate the box (90 or -90 degrees).
266
+ - w (int): The width of the image.
267
+ - h (int): The height of the image.
268
+
269
+ Returns:
270
+ - tensor: The rotated bounding box.
271
+ """
272
  x1, y1, x2, y2 = box
273
  if angle == 90:
274
+ return torch.tensor([y1, h - x2, y2, h - x1])
275
  elif angle == 270 or angle == -90:
276
+ return torch.tensor([w - y2, x1, w - y1, x2])
277
  else:
278
  print("angle not supported")
279
 
280
  def rotate_90_keypoints(kp, angle, w, h):
281
+ """
282
+ Rotate keypoints by 90 degrees.
283
+
284
+ Parameters:
285
+ - kp (tensor): The keypoints to be rotated.
286
+ - angle (int): The angle to rotate the keypoints (90 or -90 degrees).
287
+ - w (int): The width of the image.
288
+ - h (int): The height of the image.
289
+
290
+ Returns:
291
+ - tensor: The rotated keypoints.
292
+ """
293
  # Extract coordinates and visibility from each keypoint tensor
294
  x1, y1, v1 = kp[0][0], kp[0][1], kp[0][2]
295
  x2, y2, v2 = kp[1][0], kp[1][1], kp[1][2]
296
  # Swap x and y coordinates for each keypoint
297
  if angle == 90:
298
+ new = [[y1, h - x1, v1], [y2, h - x2, v2]]
299
  elif angle == 270 or angle == -90:
300
+ new = [[w - y1, x1, v1], [w - y2, x2, v2]]
301
 
302
  return torch.tensor(new, dtype=torch.float32)
 
303
 
304
  def rotate_vertical(image, target):
305
+ """
306
+ Rotate the image and target if the image is vertical.
307
+
308
+ Parameters:
309
+ - image (PIL Image): The image to be rotated.
310
+ - target (dict): The target dictionary containing 'boxes' and 'keypoints'.
311
+
312
+ Returns:
313
+ - PIL Image, dict: The rotated image and its updated target dictionary.
314
+ """
315
  new_boxes = []
316
+ angle = random.choice([-90, 90])
317
  image = F.rotate(image, angle, expand=True, fill=200)
318
  for box in target["boxes"]:
319
  new_box = rotate_90_box(box, angle, image.size[0], image.size[1])
320
  new_boxes.append(new_box)
321
  target["boxes"] = torch.stack(new_boxes)
322
+
323
  if 'keypoints' in target:
324
+ new_kp = []
325
+ for kp in target['keypoints']:
326
  new_key = rotate_90_keypoints(kp, angle, image.size[0], image.size[1])
327
  new_kp.append(new_key)
328
  target['keypoints'] = torch.stack(new_kp)
329
  return image, target
330
 
331
+ def resize_and_pad(image, target, new_size=(1333, 800)):
332
+ """
333
+ Resize and pad the image and target to the specified new size while maintaining the aspect ratio.
334
 
335
+ Parameters:
336
+ - image (PIL Image): The image to be resized and padded.
337
+ - target (dict): The target dictionary containing 'boxes' and optional 'keypoints'.
338
+ - new_size (tuple): The target size for the image after resizing and padding.
339
 
340
+ Returns:
341
+ - PIL Image, dict: The resized and padded image and its updated target dictionary.
342
+ """
343
  original_size = image.size
344
  # Calculate scale to fit the new size while maintaining aspect ratio
345
  scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
 
372
  return image, target
373
 
374
  class BPMN_Dataset(Dataset):
375
+ def __init__(self, annotations, transform=None, crop_transform=None, crop_prob=0.3, rotate_90_proba=0.2,
376
+ flip_transform=None, rotate_transform=None, new_size=(1333, 1333), keep_ratio=0.1, resize=True, model_type='object'):
377
+ """
378
+ Initialize the BPMN_Dataset with annotations and optional transformations.
379
+
380
+ Parameters:
381
+ - annotations (list): List of annotations for the dataset.
382
+ - transform (callable, optional): Transformation function to apply to each image.
383
+ - crop_transform (callable, optional): Custom cropping transformation.
384
+ - crop_prob (float): Probability of applying the crop transformation.
385
+ - rotate_90_proba (float): Probability of rotating the image by 90 degrees.
386
+ - flip_transform (callable, optional): Custom flipping transformation.
387
+ - rotate_transform (callable, optional): Custom rotation transformation.
388
+ - new_size (tuple): Target size for the images.
389
+ - keep_ratio (float): Probability of keeping the aspect ratio during resizing.
390
+ - resize (bool): Flag indicating whether to resize images after transformations.
391
+ - model_type (str): Type of model ('object' or 'arrow') to determine the target dictionary.
392
+ """
393
  self.annotations = annotations
394
  print(f"Loaded {len(self.annotations)} annotations.")
395
  self.transform = transform
 
408
  self.rotate_90_proba = rotate_90_proba
409
 
410
  def __len__(self):
411
+ """
412
+ Return the number of annotations in the dataset.
413
+
414
+ Returns:
415
+ - int: The number of annotations.
416
+ """
417
  return len(self.annotations)
418
 
419
  def __getitem__(self, idx):
420
+ """
421
+ Get an item (image and target) from the dataset at the specified index.
422
+
423
+ Parameters:
424
+ - idx (int): The index of the item to retrieve.
425
+
426
+ Returns:
427
+ - PIL Image, dict: The transformed image and its updated target dictionary.
428
+ """
429
  annotation = self.annotations[idx]
430
  image = annotation.img.convert("RGB")
431
  boxes = torch.tensor(np.array(annotation.boxes_ltrb), dtype=torch.float32)
432
  labels_names = [ann for ann in annotation.categories]
433
 
434
+ # Only keep the labels, boxes, and keypoints that are in the class_dict
435
  kept_indices = [i for i, ann in enumerate(annotation.categories) if ann in self.dict.values()]
436
  boxes = boxes[kept_indices]
437
  labels_names = [ann for i, ann in enumerate(labels_names) if i in kept_indices]
 
452
  if ann.category in ["sequenceFlow", "messageFlow", "dataAssociation"]:
453
  # Fill the keypoints tensor for this annotation, mark as visible (1)
454
  kp = np.array(ann.keypoints, dtype=np.float32).reshape(-1, 3)
455
+ kp = kp[:, :2]
456
  visible = np.ones((kp.shape[0], 1), dtype=np.float32)
457
  kp = np.hstack([kp, visible])
458
  keypoints[ii, :kp.shape[0], :] = torch.tensor(kp, dtype=torch.float32)
 
460
 
461
  area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
462
 
463
+ if self.model_type == 'object':
464
  target = {
465
  "boxes": boxes,
466
  "labels": labels_id,
467
+ # "area": area,
468
  }
469
  elif self.model_type == 'arrow':
470
  target = {
471
  "boxes": boxes,
472
  "labels": labels_id,
473
+ # "area": area,
474
  "keypoints": keypoints,
475
  }
476
 
 
485
  # Randomly apply the custom cropping transform
486
  if self.crop_transform and random.random() < self.crop_prob:
487
  image, target = self.crop_transform(image, target)
488
+
489
  # Rotate vertical image
490
  if random.random() < self.rotate_90_proba:
491
  image, target = rotate_vertical(image, target)
 
495
  # Center and pad the image while keeping the aspect ratio
496
  image, target = resize_and_pad(image, target, self.new_size)
497
  else:
498
+ target['boxes'] = resize_boxes(target['boxes'], (image.size[0], image.size[1]), self.new_size)
499
  if 'area' in target:
500
  target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
501
  if 'keypoints' in target:
502
  for i in range(len(target['keypoints'])):
503
+ target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0], image.size[1]), self.new_size)
504
  image = F.resize(image, (self.new_size[1], self.new_size[0]))
505
 
506
  return self.transform(image), target
 
530
  return images, targets
531
 
532
 
533
+ def create_loader(new_size, transformation, annotations1, annotations2=None,
 
534
  batch_size=4, crop_prob=0.2, crop_fraction=0.7, min_objects=3,
535
  h_flip_prob=0.3, v_flip_prob=0.3, max_rotate_deg=20, rotate_90_proba=0.2, rotate_proba=0.3,
536
+ seed=42, resize=True, keep_ratio=0.1, model_type='object'):
537
  """
538
+ Create a DataLoader for BPMN datasets with optional transformations and concatenation of two datasets.
539
 
540
  Parameters:
541
+ - new_size (tuple): The target size for the images.
542
  - transformation (callable): Transformation function to apply to each image (e.g., normalization).
543
  - annotations1 (list): Primary list of annotations.
544
  - annotations2 (list, optional): Secondary list of annotations to concatenate with the first.
 
548
  - min_objects (int): Minimum number of objects required to be within the crop.
549
  - h_flip_prob (float): Probability of applying horizontal flip.
550
  - v_flip_prob (float): Probability of applying vertical flip.
551
+ - max_rotate_deg (int): Maximum degree to rotate the image.
552
+ - rotate_90_proba (float): Probability of rotating the image by 90 degrees.
553
+ - rotate_proba (float): Probability of applying rotation to the image.
554
  - seed (int): Seed for random number generators for reproducibility.
555
  - resize (bool): Flag indicating whether to resize images after transformations.
556
+ - keep_ratio (float): Probability of keeping the aspect ratio during resizing.
557
+ - model_type (str): Type of model ('object' or 'arrow') to determine the target dictionary.
558
 
559
  Returns:
560
  - DataLoader: Configured data loader for the dataset.
561
  """
562
 
563
  # Initialize custom transformations for cropping and flipping
564
+ custom_crop_transform = RandomCrop(new_size, crop_fraction, min_objects)
565
  custom_flip_transform = RandomFlip(h_flip_prob, v_flip_prob)
566
  custom_rotate_transform = RandomRotate(max_rotate_deg, rotate_proba)
567
 
 
603
  # Create the DataLoader with the dataset
604
  data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
605
 
606
+ return data_loader
modules/eval.py CHANGED
@@ -9,6 +9,18 @@ from builtins import dict
9
 
10
 
11
  def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5):
 
 
 
 
 
 
 
 
 
 
 
 
12
  exception = ['pool', 'lane']
13
 
14
  idxs = np.argsort(scores) # Sort the boxes according to their scores in ascending order
@@ -40,6 +52,19 @@ def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5):
40
 
41
 
42
  def keypoint_correction(keypoints, boxes, labels, model_dict=arrow_dict, distance_treshold=15):
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  for idx, (key1, key2) in enumerate(keypoints):
44
  if labels[idx] not in [list(model_dict.values()).index('sequenceFlow'),
45
  list(model_dict.values()).index('messageFlow'),
@@ -49,14 +74,26 @@ def keypoint_correction(keypoints, boxes, labels, model_dict=arrow_dict, distanc
49
  distance = np.linalg.norm(key1[:2] - key2[:2])
50
  if distance < distance_treshold:
51
  print('Key modified for index:', idx)
52
- x_new,y_new, x,y = find_other_keypoint(idx, keypoints, boxes)
53
- keypoints[idx][0][:2] = [x_new,y_new]
54
- keypoints[idx][1][:2] = [x,y]
55
 
56
  return keypoints
57
 
58
 
59
  def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
 
 
 
 
 
 
 
 
 
 
 
 
60
  model.eval()
61
  with torch.no_grad():
62
  image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
@@ -73,7 +110,7 @@ def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
73
 
74
  selected_boxes = non_maximum_suppression(boxes, scores, labels=labels, iou_threshold=iou_threshold)
75
 
76
- #find orientation of the task by checking the size of all the boxes and delete the one that are not in the same orientation
77
  vertical = 0
78
  for i in range(len(labels)):
79
  if labels[i] != list(object_dict.values()).index('task'):
@@ -87,12 +124,12 @@ def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
87
 
88
  if vertical < horizontal:
89
  if is_vertical(boxes[i]):
90
- #find the element in the list and remove it
91
  if i in selected_boxes:
92
  selected_boxes.remove(i)
93
  elif vertical > horizontal:
94
  if is_vertical(boxes[i]) == False:
95
- #find the element in the list and remove it
96
  if i in selected_boxes:
97
  selected_boxes.remove(i)
98
  else:
@@ -102,23 +139,21 @@ def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
102
  scores = scores[selected_boxes]
103
  labels = labels[selected_boxes]
104
 
105
- #find the outlier object that are too small by the area
106
- obj_not_too_small = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref = ['event', 'messageEvent'], mode = "lower")
107
- obj_not_too_big = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=2, element_ref = ['task'], mode = "upper")
108
 
109
  selected_object = [i for i in range(len(labels)) if i in obj_not_too_small and i in obj_not_too_big]
110
 
111
- #selected_object = obj_not_too_small
112
-
113
  boxes = boxes[selected_object]
114
  scores = scores[selected_object]
115
  labels = labels[selected_object]
116
 
117
- #modify the label of the sub-process to task
118
  for i in range(len(labels)):
119
  if labels[i] == list(object_dict.values()).index('subProcess'):
120
  labels[i] = list(object_dict.values()).index('task')
121
- #delete all lane and also the value in the labels and scores
122
  lane_index = [i for i in range(len(labels)) if labels[i] == list(object_dict.values()).index('lane')]
123
  boxes = np.delete(boxes, lane_index, axis=0)
124
  labels = np.delete(labels, lane_index)
@@ -137,6 +172,19 @@ def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
137
 
138
 
139
  def arrow_prediction(model, image, score_threshold=0.5, iou_threshold=0.5, distance_treshold=15):
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  model.eval()
141
  with torch.no_grad():
142
  image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
@@ -173,7 +221,18 @@ def arrow_prediction(model, image, score_threshold=0.5, iou_threshold=0.5, dista
173
 
174
  return image, prediction
175
 
 
176
  def mix_predictions(objects_pred, arrow_pred):
 
 
 
 
 
 
 
 
 
 
177
  # Initialize the list of lists for keypoints
178
  object_keypoints = []
179
 
@@ -186,7 +245,7 @@ def mix_predictions(objects_pred, arrow_pred):
186
  keypoints = [[0, 0, 0], [0, 0, 0]]
187
  object_keypoints.append(keypoints)
188
 
189
- #concatenate the two predictions
190
  if len(arrow_pred['boxes']) == 0:
191
  return objects_pred['boxes'], objects_pred['labels'], objects_pred['scores'], object_keypoints
192
 
@@ -199,6 +258,21 @@ def mix_predictions(objects_pred, arrow_pred):
199
 
200
 
201
  def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold=0.6):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  pool_dict = {}
203
 
204
  # Filter out pools with IoU greater than the threshold
@@ -265,12 +339,24 @@ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_t
265
  return pool_dict, boxes, labels, scores, keypoints
266
 
267
 
268
-
269
  def create_links(keypoints, boxes, labels, class_dict):
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  best_points = []
271
  links = []
272
  for i in range(len(labels)):
273
- if labels[i]==list(class_dict.values()).index('sequenceFlow') or labels[i]==list(class_dict.values()).index('messageFlow'):
274
  closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
275
  closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
276
 
@@ -278,11 +364,11 @@ def create_links(keypoints, boxes, labels, class_dict):
278
  best_points.append([point_start, point_end])
279
  links.append([closest1, closest2])
280
  else:
281
- best_points.append([None,None])
282
- links.append([None,None])
283
 
284
  for i in range(len(labels)):
285
- if labels[i]==list(class_dict.values()).index('dataAssociation'):
286
  closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
287
  closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
288
  if closest1 is not None and closest2 is not None:
@@ -291,7 +377,22 @@ def create_links(keypoints, boxes, labels, class_dict):
291
 
292
  return links, best_points
293
 
 
294
  def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  sequence_flow_index = list(class_dict.values()).index('sequenceFlow')
296
  message_flow_index = list(class_dict.values()).index('messageFlow')
297
  data_association_index = list(class_dict.values()).index('dataAssociation')
@@ -339,7 +440,21 @@ def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
339
  return labels, flow_links
340
 
341
 
342
- def find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref = ['event', 'messageEvent'], mode = "lower"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  # Filter out the sizes of events, data objects, and message events
344
  event_indices = [i for i, label in enumerate(labels) if class_dict[label] in element_ref]
345
  event_boxes = [boxes[i] for i in event_indices]
@@ -360,7 +475,7 @@ def find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, elem
360
  kept_indices = []
361
 
362
  if mode == "lower" or mode == 'both':
363
- #check for object that could be too small
364
  for idx, (box, label) in enumerate(zip(boxes, labels)):
365
  area = (box[2] - box[0]) * (box[3] - box[1])
366
  if not (area_lower_threshold <= area):
@@ -370,7 +485,7 @@ def find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, elem
370
  kept_indices.append(idx)
371
 
372
  if mode == "upper" or mode == 'both':
373
- #check for object that could be too big
374
  for idx, (box, label) in enumerate(zip(boxes, labels)):
375
  if label == list(class_dict.values()).index('pool') or label == list(class_dict.values()).index('lane'):
376
  kept_indices.append(idx)
@@ -382,17 +497,31 @@ def find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, elem
382
  else:
383
  kept_indices.append(idx)
384
 
385
-
386
  return kept_indices
387
 
388
 
389
-
390
  def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict, limit_area=10000):
391
-
392
- #delete pool that are have only messageFlow on it
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  delete_pool = []
394
  for pool_index, elements in pool_dict.items():
395
- #find the position of the pool_index in the bpmn_id
396
  if pool_index in bpmn_id:
397
  position = bpmn_id.index(pool_index)
398
  else:
@@ -405,11 +534,11 @@ def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_point
405
  delete_pool.append(position)
406
  print(f"Pool {pool_index} contains only arrow elements, deleting it")
407
 
408
- #calcul the area of the pool$
409
  if position < len(boxes):
410
  pool = boxes[position]
411
  area = (pool[2] - pool[0]) * (pool[3] - pool[1])
412
- if len(pool_dict)>1 and area < limit_area:
413
  delete_pool.append(position)
414
  print(f"Pool {pool_index} is too small, deleting it")
415
 
@@ -417,34 +546,23 @@ def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_point
417
  delete_pool.append(position)
418
  print(f"Pool {position} is vertical, deleting it")
419
 
420
-
421
  delete_elements = []
422
  # Check if there is an arrow that has the same links
423
  for i in range(len(labels)):
424
- for j in range(i+1, len(labels)):
425
  if labels[i] == list(class_dict.values()).index('sequenceFlow') and labels[j] == list(class_dict.values()).index('sequenceFlow'):
426
  if links[i] == links[j]:
427
- print(f'element {i} and {j} have the same links')
428
  if scores[i] > scores[j]:
429
- print('delete element', j)
430
  delete_elements.append(j)
431
  else:
432
- print('delete element', i)
433
  delete_elements.append(i)
434
 
435
- #filter box that are inside a text box
436
- """tex_pred = st.session_state.text_pred
437
- for i in range(len(boxes)):
438
- for j in range(len(tex_pred[0])):
439
- #check if the box is inside the text box but if the text box is inside the box then it is not a problem
440
- if proportion_inside(boxes[i], tex_pred[0][j]) > 0.1:
441
- #delete_elements.append(i)
442
- print('delete element', i)"""
443
-
444
-
445
- #concatenate the delete_elements and the delete_pool
446
  delete_elements = delete_elements + delete_pool
447
- #delete double value in delete_elements
448
  delete_elements = list(set(delete_elements))
449
 
450
  boxes = np.delete(boxes, delete_elements, axis=0)
@@ -456,74 +574,129 @@ def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_point
456
  best_points = [point for i, point in enumerate(best_points) if i not in delete_elements]
457
 
458
  for i in range(len(delete_pool)):
459
- #find the bpmn_id of the pool
460
  pool_index = bpmn_id[delete_pool[i]]
461
- #delete the pool_index in pool_dict
462
  del pool_dict[pool_index]
463
 
464
  bpmn_id = [point for i, point in enumerate(bpmn_id) if i not in delete_elements]
465
 
466
- #also delete the element in the pool_dict
467
  for pool_index, elements in pool_dict.items():
468
  pool_dict[pool_index] = [i for i in elements if i not in delete_elements]
469
 
470
  return boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict
471
 
 
472
  def give_link_to_element(links, labels):
473
- #give a link to event to allow the creation of the BPMN id with start, indermediate and end event
474
- for i in range(len(links)):
475
- if labels[i] == list(class_dict.values()).index('sequenceFlow'):
476
- id1, id2 = links[i]
477
- if (id1 and id2) is not None:
478
- links[id1][1] = i
479
- links[id2][0] = i
480
- return links
 
 
 
 
 
 
 
 
 
 
481
 
482
 
483
  def generate_data(image, boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict):
484
- idx = []
485
- for i in range(len(labels)):
486
- idx.append(i)
487
-
488
-
489
- data = {
490
- 'image': image,
491
- 'idx': idx,
492
- 'boxes': boxes,
493
- 'labels': labels,
494
- 'scores': scores,
495
- 'keypoints': keypoints,
496
- 'links': flow_links,
497
- 'best_points': best_points,
498
- 'pool_dict': pool_dict,
499
- 'BPMN_id': bpmn_id,
500
- }
501
-
502
-
503
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
 
505
- def develop_prediction(boxes, labels, scores, keypoints, class_dict):
506
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
  pool_dict, boxes, labels, scores, keypoints = regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict)
508
 
509
- bpmn_id, pool_dict = create_BPMN_id(labels,pool_dict)
510
 
511
  # Create links between elements
512
  flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
513
 
514
- #Correct the labels of some sequenceflow that cross multiple pool
515
  labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)
516
 
517
- #give a link to event to allow the creation of the BPMN id with start, indermediate and end event
518
  flow_links = give_link_to_element(flow_links, labels)
519
 
520
- boxes,labels,scores,keypoints,bpmn_id, flow_links,best_points,pool_dict = last_correction(boxes,labels,scores,keypoints,bpmn_id,flow_links,best_points, pool_dict)
 
 
521
 
522
- return boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict
523
 
524
-
525
 
526
  def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_threshold=0.5, resize=True, distance_treshold=15):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  model_object.eval() # Set the model to evaluation mode
528
  model_arrow.eval() # Set the model to evaluation mode
529
 
@@ -536,7 +709,9 @@ def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_t
536
 
537
  boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)
538
 
539
- boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(boxes, labels, scores, keypoints, class_dict)
 
 
540
 
541
  image = image.permute(1, 2, 0).cpu().numpy()
542
  image = (image * 255).astype(np.uint8)
@@ -545,7 +720,22 @@ def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_t
545
 
546
  return image, data
547
 
 
548
  def evaluate_model_by_class(pred_boxes, true_boxes, pred_labels, true_labels, model_dict, iou_threshold=0.5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
  # Initialize dictionaries to hold per-class counts
550
  class_tp = {cls: 0 for cls in model_dict.values()}
551
  class_fp = {cls: 0 for cls in model_dict.values()}
@@ -589,10 +779,25 @@ def evaluate_model_by_class(pred_boxes, true_boxes, pred_labels, true_labels, mo
589
  return class_precision, class_recall, class_f1_score
590
 
591
 
592
- def keypoints_mesure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold=5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593
  result = 0
594
  reverted = False
595
- #find the position of keypoints in the list
596
  idx = np.where(pred_boxes == pred_box)[0][0]
597
  idx2 = np.where(true_boxes == true_box)[0][0]
598
 
@@ -615,7 +820,24 @@ def keypoints_mesure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints,
615
 
616
  return result, reverted
617
 
 
618
  def evaluate_single_image(pred_boxes, true_boxes, pred_labels, true_labels, pred_keypoints, true_keypoints, iou_threshold=0.5, distance_threshold=5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619
  tp, fp, fn = 0, 0, 0
620
  key_t, key_f = 0, 0
621
  labels_t, labels_f = 0, 0
@@ -630,7 +852,9 @@ def evaluate_single_image(pred_boxes, true_boxes, pred_labels, true_labels, pred
630
  iou_val = iou(pred_box, true_box)
631
  if iou_val >= iou_threshold:
632
  if true_keypoints is not None and pred_keypoints is not None:
633
- key_result, reverted = keypoints_mesure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold)
 
 
634
  key_t += key_result
635
  key_f += 2 - key_result
636
  if reverted:
@@ -653,6 +877,21 @@ def evaluate_single_image(pred_boxes, true_boxes, pred_labels, true_labels, pred
653
 
654
 
655
  def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type='object'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656
  model.eval()
657
  tp, fp, fn = 0, 0, 0
658
  labels_t, labels_f = 0, 0
@@ -690,7 +929,7 @@ def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, dis
690
  filtered_labels = []
691
  filtered_keypoints = []
692
  if 'keypoints' not in prediction:
693
- #create a list of zeros of length equal to the number of boxes
694
  pred_keypoints = [np.zeros((2, 3)) for _ in range(len(pred_boxes))]
695
 
696
  for box, score, label, keypoints in zip(pred_boxes, scores, pred_labels, pred_keypoints):
@@ -707,7 +946,8 @@ def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, dis
707
  filtered_keypoints = None
708
  true_keypoints = None
709
  tp_img, fp_img, fn_img, labels_t_img, labels_f_img, key_t_img, key_f_img, reverted_img = evaluate_single_image(
710
- filtered_boxes, true_boxes, filtered_labels, true_labels, filtered_keypoints, true_keypoints, iou_threshold, distance_threshold)
 
711
 
712
  tp += tp_img
713
  fp += fp_img
@@ -720,9 +960,26 @@ def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, dis
720
 
721
  return tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted
722
 
723
- def main_evaluation(model, test_loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type = 'object'):
724
 
725
- tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted = pred_4_evaluation(model, test_loader, score_threshold, iou_threshold, distance_threshold, key_correction, model_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
726
 
727
  labels_precision = labels_t / (labels_t + labels_f) if (labels_t + labels_f) > 0 else 0
728
  precision = tp / (tp + fp) if (tp + fp) > 0 else 0
@@ -738,8 +995,21 @@ def main_evaluation(model, test_loader, score_threshold=0.5, iou_threshold=0.5,
738
  return labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy
739
 
740
 
741
-
742
  def evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, true_labels, class_tp, class_fp, class_fn, model_dict, iou_threshold=0.5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
743
  matched_true_boxes = set()
744
  for pred_idx, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)):
745
  match_found = False
@@ -758,7 +1028,20 @@ def evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, tr
758
  if idx not in matched_true_boxes:
759
  class_fn[model_dict[true_label]] += 1
760
 
 
761
  def pred_4_evaluation_per_class(model, loader, score_threshold=0.5, iou_threshold=0.5):
 
 
 
 
 
 
 
 
 
 
 
 
762
  model.eval()
763
  with torch.no_grad():
764
  for images, targets_im in tqdm(loader, desc="Testing... "):
@@ -788,7 +1071,21 @@ def pred_4_evaluation_per_class(model, loader, score_threshold=0.5, iou_threshol
788
 
789
  yield pred_boxes, true_boxes, pred_labels, true_labels
790
 
 
791
  def evaluate_model_by_class(model, test_loader, model_dict, score_threshold=0.5, iou_threshold=0.5):
 
 
 
 
 
 
 
 
 
 
 
 
 
792
  class_tp = {cls: 0 for cls in model_dict.values()}
793
  class_fp = {cls: 0 for cls in model_dict.values()}
794
  class_fn = {cls: 0 for cls in model_dict.values()}
@@ -809,4 +1106,4 @@ def evaluate_model_by_class(model, test_loader, model_dict, score_threshold=0.5,
809
  class_recall[cls] = recall
810
  class_f1_score[cls] = f1_score
811
 
812
- return class_precision, class_recall, class_f1_score
 
9
 
10
 
11
  def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5):
12
+ """
13
+ Perform non-maximum suppression to filter out overlapping bounding boxes.
14
+
15
+ Parameters:
16
+ - boxes (array): Array of bounding boxes.
17
+ - scores (array): Array of confidence scores for each bounding box.
18
+ - labels (array, optional): Array of labels for each bounding box.
19
+ - iou_threshold (float): Intersection-over-Union threshold to use for filtering.
20
+
21
+ Returns:
22
+ - list: Indices of selected boxes after suppression.
23
+ """
24
  exception = ['pool', 'lane']
25
 
26
  idxs = np.argsort(scores) # Sort the boxes according to their scores in ascending order
 
52
 
53
 
54
  def keypoint_correction(keypoints, boxes, labels, model_dict=arrow_dict, distance_treshold=15):
55
+ """
56
+ Correct keypoints that are too close together by adjusting their positions.
57
+
58
+ Parameters:
59
+ - keypoints (array): Array of keypoints.
60
+ - boxes (array): Array of bounding boxes.
61
+ - labels (array): Array of labels for each bounding box.
62
+ - model_dict (dict): Dictionary mapping model labels to indices.
63
+ - distance_treshold (int): Distance threshold below which keypoints are considered too close.
64
+
65
+ Returns:
66
+ - array: Corrected keypoints.
67
+ """
68
  for idx, (key1, key2) in enumerate(keypoints):
69
  if labels[idx] not in [list(model_dict.values()).index('sequenceFlow'),
70
  list(model_dict.values()).index('messageFlow'),
 
74
  distance = np.linalg.norm(key1[:2] - key2[:2])
75
  if distance < distance_treshold:
76
  print('Key modified for index:', idx)
77
+ x_new, y_new, x, y = find_other_keypoint(idx, keypoints, boxes)
78
+ keypoints[idx][0][:2] = [x_new, y_new]
79
+ keypoints[idx][1][:2] = [x, y]
80
 
81
  return keypoints
82
 
83
 
84
  def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
85
+ """
86
+ Perform object detection prediction using the model.
87
+
88
+ Parameters:
89
+ - model (torch.nn.Module): The object detection model.
90
+ - image (torch.Tensor): The input image.
91
+ - score_threshold (float): Score threshold for filtering predictions.
92
+ - iou_threshold (float): IoU threshold for non-maximum suppression.
93
+
94
+ Returns:
95
+ - numpy.array, dict: The processed image and the prediction dictionary containing 'boxes', 'scores', and 'labels'.
96
+ """
97
  model.eval()
98
  with torch.no_grad():
99
  image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
 
110
 
111
  selected_boxes = non_maximum_suppression(boxes, scores, labels=labels, iou_threshold=iou_threshold)
112
 
113
+ # Find orientation of the task by checking the size of all the boxes and delete the ones that are not in the same orientation
114
  vertical = 0
115
  for i in range(len(labels)):
116
  if labels[i] != list(object_dict.values()).index('task'):
 
124
 
125
  if vertical < horizontal:
126
  if is_vertical(boxes[i]):
127
+ # Find the element in the list and remove it
128
  if i in selected_boxes:
129
  selected_boxes.remove(i)
130
  elif vertical > horizontal:
131
  if is_vertical(boxes[i]) == False:
132
+ # Find the element in the list and remove it
133
  if i in selected_boxes:
134
  selected_boxes.remove(i)
135
  else:
 
139
  scores = scores[selected_boxes]
140
  labels = labels[selected_boxes]
141
 
142
+ # Find the outlier objects that are too small by the area
143
+ obj_not_too_small = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref=['event', 'messageEvent'], mode="lower")
144
+ obj_not_too_big = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=2, element_ref=['task'], mode="upper")
145
 
146
  selected_object = [i for i in range(len(labels)) if i in obj_not_too_small and i in obj_not_too_big]
147
 
 
 
148
  boxes = boxes[selected_object]
149
  scores = scores[selected_object]
150
  labels = labels[selected_object]
151
 
152
+ # Modify the label of the sub-process to task
153
  for i in range(len(labels)):
154
  if labels[i] == list(object_dict.values()).index('subProcess'):
155
  labels[i] = list(object_dict.values()).index('task')
156
+ # Delete all lane and also the value in the labels and scores
157
  lane_index = [i for i in range(len(labels)) if labels[i] == list(object_dict.values()).index('lane')]
158
  boxes = np.delete(boxes, lane_index, axis=0)
159
  labels = np.delete(labels, lane_index)
 
172
 
173
 
174
  def arrow_prediction(model, image, score_threshold=0.5, iou_threshold=0.5, distance_treshold=15):
175
+ """
176
+ Perform arrow detection prediction using the model.
177
+
178
+ Parameters:
179
+ - model (torch.nn.Module): The arrow detection model.
180
+ - image (torch.Tensor): The input image.
181
+ - score_threshold (float): Score threshold for filtering predictions.
182
+ - iou_threshold (float): IoU threshold for non-maximum suppression.
183
+ - distance_treshold (int): Distance threshold for keypoint correction.
184
+
185
+ Returns:
186
+ - numpy.array, dict: The processed image and the prediction dictionary containing 'boxes', 'scores', 'labels', and 'keypoints'.
187
+ """
188
  model.eval()
189
  with torch.no_grad():
190
  image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
 
221
 
222
  return image, prediction
223
 
224
+
225
  def mix_predictions(objects_pred, arrow_pred):
226
+ """
227
+ Combine object and arrow predictions into a single set of predictions.
228
+
229
+ Parameters:
230
+ - objects_pred (dict): Object predictions dictionary.
231
+ - arrow_pred (dict): Arrow predictions dictionary.
232
+
233
+ Returns:
234
+ - tuple: Combined boxes, labels, scores, and keypoints.
235
+ """
236
  # Initialize the list of lists for keypoints
237
  object_keypoints = []
238
 
 
245
  keypoints = [[0, 0, 0], [0, 0, 0]]
246
  object_keypoints.append(keypoints)
247
 
248
+ # Concatenate the two predictions
249
  if len(arrow_pred['boxes']) == 0:
250
  return objects_pred['boxes'], objects_pred['labels'], objects_pred['scores'], object_keypoints
251
 
 
258
 
259
 
260
  def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold=0.6):
261
+ """
262
+ Regroup elements by pool based on IoU and proximity.
263
+
264
+ Parameters:
265
+ - boxes (array): Array of bounding boxes.
266
+ - labels (array): Array of labels for each bounding box.
267
+ - scores (array): Array of confidence scores for each bounding box.
268
+ - keypoints (array): Array of keypoints.
269
+ - class_dict (dict): Dictionary mapping class names to indices.
270
+ - iou_threshold (float): IoU threshold for grouping.
271
+
272
+ Returns:
273
+ - dict: Dictionary grouping elements by pool.
274
+ - array: Updated arrays of boxes, labels, scores, and keypoints.
275
+ """
276
  pool_dict = {}
277
 
278
  # Filter out pools with IoU greater than the threshold
 
339
  return pool_dict, boxes, labels, scores, keypoints
340
 
341
 
 
342
  def create_links(keypoints, boxes, labels, class_dict):
343
+ """
344
+ Create links between elements based on keypoints.
345
+
346
+ Parameters:
347
+ - keypoints (array): Array of keypoints.
348
+ - boxes (array): Array of bounding boxes.
349
+ - labels (array): Array of labels for each bounding box.
350
+ - class_dict (dict): Dictionary mapping class names to indices.
351
+
352
+ Returns:
353
+ - list: List of links between elements.
354
+ - list: List of best points for each link.
355
+ """
356
  best_points = []
357
  links = []
358
  for i in range(len(labels)):
359
+ if labels[i] == list(class_dict.values()).index('sequenceFlow') or labels[i] == list(class_dict.values()).index('messageFlow'):
360
  closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
361
  closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
362
 
 
364
  best_points.append([point_start, point_end])
365
  links.append([closest1, closest2])
366
  else:
367
+ best_points.append([None, None])
368
+ links.append([None, None])
369
 
370
  for i in range(len(labels)):
371
+ if labels[i] == list(class_dict.values()).index('dataAssociation'):
372
  closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
373
  closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
374
  if closest1 is not None and closest2 is not None:
 
377
 
378
  return links, best_points
379
 
380
+
381
  def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
382
+ """
383
+ Correct labels based on the relationships between elements and pools.
384
+
385
+ Parameters:
386
+ - boxes (array): Array of bounding boxes.
387
+ - labels (array): Array of labels for each bounding box.
388
+ - class_dict (dict): Dictionary mapping class names to indices.
389
+ - pool_dict (dict): Dictionary grouping elements by pool.
390
+ - flow_links (list): List of links between elements.
391
+
392
+ Returns:
393
+ - array: Corrected labels.
394
+ - list: Updated flow links.
395
+ """
396
  sequence_flow_index = list(class_dict.values()).index('sequenceFlow')
397
  message_flow_index = list(class_dict.values()).index('messageFlow')
398
  data_association_index = list(class_dict.values()).index('dataAssociation')
 
440
  return labels, flow_links
441
 
442
 
443
+ def find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref=['event', 'messageEvent'], mode="lower"):
444
+ """
445
+ Identify outlier objects based on their area.
446
+
447
+ Parameters:
448
+ - boxes (array): Array of bounding boxes.
449
+ - labels (array): Array of labels for each bounding box.
450
+ - class_dict (dict): Dictionary mapping class names to indices.
451
+ - std_factor (float): Standard deviation factor for determining outliers.
452
+ - element_ref (list): List of reference elements for calculating area statistics.
453
+ - mode (str): Mode to identify outliers ('lower', 'upper', or 'both').
454
+
455
+ Returns:
456
+ - list: Indices of kept objects that are not outliers.
457
+ """
458
  # Filter out the sizes of events, data objects, and message events
459
  event_indices = [i for i, label in enumerate(labels) if class_dict[label] in element_ref]
460
  event_boxes = [boxes[i] for i in event_indices]
 
475
  kept_indices = []
476
 
477
  if mode == "lower" or mode == 'both':
478
+ # Check for objects that could be too small
479
  for idx, (box, label) in enumerate(zip(boxes, labels)):
480
  area = (box[2] - box[0]) * (box[3] - box[1])
481
  if not (area_lower_threshold <= area):
 
485
  kept_indices.append(idx)
486
 
487
  if mode == "upper" or mode == 'both':
488
+ # Check for objects that could be too big
489
  for idx, (box, label) in enumerate(zip(boxes, labels)):
490
  if label == list(class_dict.values()).index('pool') or label == list(class_dict.values()).index('lane'):
491
  kept_indices.append(idx)
 
497
  else:
498
  kept_indices.append(idx)
499
 
 
500
  return kept_indices
501
 
502
 
 
503
  def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict, limit_area=10000):
504
+ """
505
+ Perform final corrections on the predictions by deleting irrelevant or small pools and duplicate elements.
506
+
507
+ Parameters:
508
+ - boxes (array): Array of bounding boxes.
509
+ - labels (array): Array of labels for each bounding box.
510
+ - scores (array): Array of confidence scores for each bounding box.
511
+ - keypoints (array): Array of keypoints.
512
+ - bpmn_id (list): List of BPMN IDs.
513
+ - links (list): List of links between elements.
514
+ - best_points (list): List of best points for each link.
515
+ - pool_dict (dict): Dictionary grouping elements by pool.
516
+ - limit_area (int): Minimum area threshold for pools.
517
+
518
+ Returns:
519
+ - tuple: Corrected arrays of boxes, labels, scores, keypoints, BPMN IDs, links, best points, and pool dictionary.
520
+ """
521
+ # Delete pools that have only messageFlow on it
522
  delete_pool = []
523
  for pool_index, elements in pool_dict.items():
524
+ # Find the position of the pool_index in the bpmn_id
525
  if pool_index in bpmn_id:
526
  position = bpmn_id.index(pool_index)
527
  else:
 
534
  delete_pool.append(position)
535
  print(f"Pool {pool_index} contains only arrow elements, deleting it")
536
 
537
+ # Calculate the area of the pool
538
  if position < len(boxes):
539
  pool = boxes[position]
540
  area = (pool[2] - pool[0]) * (pool[3] - pool[1])
541
+ if len(pool_dict) > 1 and area < limit_area:
542
  delete_pool.append(position)
543
  print(f"Pool {pool_index} is too small, deleting it")
544
 
 
546
  delete_pool.append(position)
547
  print(f"Pool {position} is vertical, deleting it")
548
 
 
549
  delete_elements = []
550
  # Check if there is an arrow that has the same links
551
  for i in range(len(labels)):
552
+ for j in range(i + 1, len(labels)):
553
  if labels[i] == list(class_dict.values()).index('sequenceFlow') and labels[j] == list(class_dict.values()).index('sequenceFlow'):
554
  if links[i] == links[j]:
555
+ print(f'Element {i} and {j} have the same links')
556
  if scores[i] > scores[j]:
557
+ print('Delete element', j)
558
  delete_elements.append(j)
559
  else:
560
+ print('Delete element', i)
561
  delete_elements.append(i)
562
 
563
+ # Concatenate the delete_elements and the delete_pool
 
 
 
 
 
 
 
 
 
 
564
  delete_elements = delete_elements + delete_pool
565
+ # Delete double value in delete_elements
566
  delete_elements = list(set(delete_elements))
567
 
568
  boxes = np.delete(boxes, delete_elements, axis=0)
 
574
  best_points = [point for i, point in enumerate(best_points) if i not in delete_elements]
575
 
576
  for i in range(len(delete_pool)):
577
+ # Find the bpmn_id of the pool
578
  pool_index = bpmn_id[delete_pool[i]]
579
+ # Delete the pool_index in pool_dict
580
  del pool_dict[pool_index]
581
 
582
  bpmn_id = [point for i, point in enumerate(bpmn_id) if i not in delete_elements]
583
 
584
+ # Also delete the element in the pool_dict
585
  for pool_index, elements in pool_dict.items():
586
  pool_dict[pool_index] = [i for i in elements if i not in delete_elements]
587
 
588
  return boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict
589
 
590
+
591
  def give_link_to_element(links, labels):
592
+ """
593
+ Assign links to elements to create BPMN IDs for events.
594
+
595
+ Parameters:
596
+ - links (list): List of links between elements.
597
+ - labels (array): Array of labels for each bounding box.
598
+
599
+ Returns:
600
+ - list: Updated list of links with assigned links for events.
601
+ """
602
+ # Give a link to event to allow the creation of the BPMN ID with start, intermediate, and end event
603
+ for i in range(len(links)):
604
+ if labels[i] == list(class_dict.values()).index('sequenceFlow'):
605
+ id1, id2 = links[i]
606
+ if (id1 and id2) is not None:
607
+ links[id1][1] = i
608
+ links[id2][0] = i
609
+ return links
610
 
611
 
612
  def generate_data(image, boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict):
613
+ """
614
+ Generate a data dictionary containing image and prediction information.
615
+
616
+ Parameters:
617
+ - image (numpy.array): The input image.
618
+ - boxes (array): Array of bounding boxes.
619
+ - labels (array): Array of labels for each bounding box.
620
+ - scores (array): Array of confidence scores for each bounding box.
621
+ - keypoints (array): Array of keypoints.
622
+ - bpmn_id (list): List of BPMN IDs.
623
+ - flow_links (list): List of links between elements.
624
+ - best_points (list): List of best points for each link.
625
+ - pool_dict (dict): Dictionary grouping elements by pool.
626
+
627
+ Returns:
628
+ - dict: Data dictionary containing all prediction information.
629
+ """
630
+ idx = []
631
+ for i in range(len(labels)):
632
+ idx.append(i)
633
+
634
+ data = {
635
+ 'image': image,
636
+ 'idx': idx,
637
+ 'boxes': boxes,
638
+ 'labels': labels,
639
+ 'scores': scores,
640
+ 'keypoints': keypoints,
641
+ 'links': flow_links,
642
+ 'best_points': best_points,
643
+ 'pool_dict': pool_dict,
644
+ 'BPMN_id': bpmn_id,
645
+ }
646
+
647
+ return data
648
 
 
649
 
650
+ def develop_prediction(boxes, labels, scores, keypoints, class_dict):
651
+ """
652
+ Develop predictions by regrouping elements, creating links, and correcting labels.
653
+
654
+ Parameters:
655
+ - boxes (array): Array of bounding boxes.
656
+ - labels (array): Array of labels for each bounding box.
657
+ - scores (array): Array of confidence scores for each bounding box.
658
+ - keypoints (array): Array of keypoints.
659
+ - class_dict (dict): Dictionary mapping class names to indices.
660
+
661
+ Returns:
662
+ - tuple: Developed prediction components including boxes, labels, scores, keypoints, BPMN IDs, flow links, best points, and pool dictionary.
663
+ """
664
  pool_dict, boxes, labels, scores, keypoints = regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict)
665
 
666
+ bpmn_id, pool_dict = create_BPMN_id(labels, pool_dict)
667
 
668
  # Create links between elements
669
  flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
670
 
671
+ # Correct the labels of some sequenceFlow that cross multiple pools
672
  labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)
673
 
674
+ # Give a link to event to allow the creation of the BPMN ID with start, intermediate, and end event
675
  flow_links = give_link_to_element(flow_links, labels)
676
 
677
+ boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = last_correction(
678
+ boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict
679
+ )
680
 
681
+ return boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict
682
 
 
683
 
684
  def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_threshold=0.5, resize=True, distance_treshold=15):
685
+ """
686
+ Perform a full prediction by combining object and arrow models and generating data.
687
+
688
+ Parameters:
689
+ - model_object (torch.nn.Module): The object detection model.
690
+ - model_arrow (torch.nn.Module): The arrow detection model.
691
+ - image (torch.Tensor): The input image.
692
+ - score_threshold (float): Score threshold for filtering predictions.
693
+ - iou_threshold (float): IoU threshold for non-maximum suppression.
694
+ - resize (bool): Flag indicating whether to resize the image.
695
+ - distance_treshold (int): Distance threshold for keypoint correction.
696
+
697
+ Returns:
698
+ - numpy.array, dict: The processed image and the data dictionary containing prediction information.
699
+ """
700
  model_object.eval() # Set the model to evaluation mode
701
  model_arrow.eval() # Set the model to evaluation mode
702
 
 
709
 
710
  boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)
711
 
712
+ boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(
713
+ boxes, labels, scores, keypoints, class_dict
714
+ )
715
 
716
  image = image.permute(1, 2, 0).cpu().numpy()
717
  image = (image * 255).astype(np.uint8)
 
720
 
721
  return image, data
722
 
723
+
724
  def evaluate_model_by_class(pred_boxes, true_boxes, pred_labels, true_labels, model_dict, iou_threshold=0.5):
725
+ """
726
+ Evaluate the model's performance on a per-class basis.
727
+
728
+ Parameters:
729
+ - pred_boxes (array): Predicted bounding boxes.
730
+ - true_boxes (array): Ground truth bounding boxes.
731
+ - pred_labels (array): Predicted labels.
732
+ - true_labels (array): Ground truth labels.
733
+ - model_dict (dict): Dictionary mapping model labels to indices.
734
+ - iou_threshold (float): IoU threshold for determining matches.
735
+
736
+ Returns:
737
+ - tuple: Precision, recall, and F1-score per class.
738
+ """
739
  # Initialize dictionaries to hold per-class counts
740
  class_tp = {cls: 0 for cls in model_dict.values()}
741
  class_fp = {cls: 0 for cls in model_dict.values()}
 
779
  return class_precision, class_recall, class_f1_score
780
 
781
 
782
+ def keypoints_measure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold=5):
783
+ """
784
+ Measure the accuracy of predicted keypoints compared to true keypoints.
785
+
786
+ Parameters:
787
+ - pred_boxes (array): Predicted bounding boxes.
788
+ - pred_box (array): Single predicted bounding box.
789
+ - true_boxes (array): Ground truth bounding boxes.
790
+ - true_box (array): Single ground truth bounding box.
791
+ - pred_keypoints (array): Predicted keypoints.
792
+ - true_keypoints (array): Ground truth keypoints.
793
+ - distance_threshold (int): Distance threshold for considering a keypoint match.
794
+
795
+ Returns:
796
+ - tuple: Number of correct keypoints and whether the keypoints are reverted.
797
+ """
798
  result = 0
799
  reverted = False
800
+ # Find the position of keypoints in the list
801
  idx = np.where(pred_boxes == pred_box)[0][0]
802
  idx2 = np.where(true_boxes == true_box)[0][0]
803
 
 
820
 
821
  return result, reverted
822
 
823
+
824
  def evaluate_single_image(pred_boxes, true_boxes, pred_labels, true_labels, pred_keypoints, true_keypoints, iou_threshold=0.5, distance_threshold=5):
825
+ """
826
+ Evaluate a single image's predictions against the ground truth.
827
+
828
+ Parameters:
829
+ - pred_boxes (array): Predicted bounding boxes.
830
+ - true_boxes (array): Ground truth bounding boxes.
831
+ - pred_labels (array): Predicted labels.
832
+ - true_labels (array): Ground truth labels.
833
+ - pred_keypoints (array): Predicted keypoints.
834
+ - true_keypoints (array): Ground truth keypoints.
835
+ - iou_threshold (float): IoU threshold for determining matches.
836
+ - distance_threshold (int): Distance threshold for considering a keypoint match.
837
+
838
+ Returns:
839
+ - tuple: True positives, false positives, false negatives, correct labels, incorrect labels, correct keypoints, incorrect keypoints, and reverted keypoints count.
840
+ """
841
  tp, fp, fn = 0, 0, 0
842
  key_t, key_f = 0, 0
843
  labels_t, labels_f = 0, 0
 
852
  iou_val = iou(pred_box, true_box)
853
  if iou_val >= iou_threshold:
854
  if true_keypoints is not None and pred_keypoints is not None:
855
+ key_result, reverted = keypoints_measure(
856
+ pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold
857
+ )
858
  key_t += key_result
859
  key_f += 2 - key_result
860
  if reverted:
 
877
 
878
 
879
  def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type='object'):
880
+ """
881
+ Evaluate the model on a dataset using predictions for evaluation.
882
+
883
+ Parameters:
884
+ - model (torch.nn.Module): The model to evaluate.
885
+ - loader (torch.utils.data.DataLoader): DataLoader for the dataset.
886
+ - score_threshold (float): Score threshold for filtering predictions.
887
+ - iou_threshold (float): IoU threshold for determining matches.
888
+ - distance_threshold (int): Distance threshold for considering a keypoint match.
889
+ - key_correction (bool): Whether to apply keypoint correction.
890
+ - model_type (str): Type of model ('object' or 'arrow').
891
+
892
+ Returns:
893
+ - tuple: Evaluation results including true positives, false positives, false negatives, correct labels, incorrect labels, correct keypoints, incorrect keypoints, and reverted keypoints count.
894
+ """
895
  model.eval()
896
  tp, fp, fn = 0, 0, 0
897
  labels_t, labels_f = 0, 0
 
929
  filtered_labels = []
930
  filtered_keypoints = []
931
  if 'keypoints' not in prediction:
932
+ # Create a list of zeros of length equal to the number of boxes
933
  pred_keypoints = [np.zeros((2, 3)) for _ in range(len(pred_boxes))]
934
 
935
  for box, score, label, keypoints in zip(pred_boxes, scores, pred_labels, pred_keypoints):
 
946
  filtered_keypoints = None
947
  true_keypoints = None
948
  tp_img, fp_img, fn_img, labels_t_img, labels_f_img, key_t_img, key_f_img, reverted_img = evaluate_single_image(
949
+ filtered_boxes, true_boxes, filtered_labels, true_labels, filtered_keypoints, true_keypoints, iou_threshold, distance_threshold
950
+ )
951
 
952
  tp += tp_img
953
  fp += fp_img
 
960
 
961
  return tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted
962
 
 
963
 
964
+ def main_evaluation(model, test_loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type='object'):
965
+ """
966
+ Main function to evaluate the model on the test dataset.
967
+
968
+ Parameters:
969
+ - model (torch.nn.Module): The model to evaluate.
970
+ - test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
971
+ - score_threshold (float): Score threshold for filtering predictions.
972
+ - iou_threshold (float): IoU threshold for determining matches.
973
+ - distance_threshold (int): Distance threshold for considering a keypoint match.
974
+ - key_correction (bool): Whether to apply keypoint correction.
975
+ - model_type (str): Type of model ('object' or 'arrow').
976
+
977
+ Returns:
978
+ - tuple: Precision, recall, F1-score, key accuracy, and reverted accuracy.
979
+ """
980
+ tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted = pred_4_evaluation(
981
+ model, test_loader, score_threshold, iou_threshold, distance_threshold, key_correction, model_type
982
+ )
983
 
984
  labels_precision = labels_t / (labels_t + labels_f) if (labels_t + labels_f) > 0 else 0
985
  precision = tp / (tp + fp) if (tp + fp) > 0 else 0
 
995
  return labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy
996
 
997
 
 
998
  def evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, true_labels, class_tp, class_fp, class_fn, model_dict, iou_threshold=0.5):
999
+ """
1000
+ Evaluate a single image's predictions on a per-class basis.
1001
+
1002
+ Parameters:
1003
+ - pred_boxes (array): Predicted bounding boxes.
1004
+ - true_boxes (array): Ground truth bounding boxes.
1005
+ - pred_labels (array): Predicted labels.
1006
+ - true_labels (array): Ground truth labels.
1007
+ - class_tp (dict): Dictionary of true positive counts per class.
1008
+ - class_fp (dict): Dictionary of false positive counts per class.
1009
+ - class_fn (dict): Dictionary of false negative counts per class.
1010
+ - model_dict (dict): Dictionary mapping model labels to indices.
1011
+ - iou_threshold (float): IoU threshold for determining matches.
1012
+ """
1013
  matched_true_boxes = set()
1014
  for pred_idx, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)):
1015
  match_found = False
 
1028
  if idx not in matched_true_boxes:
1029
  class_fn[model_dict[true_label]] += 1
1030
 
1031
+
1032
  def pred_4_evaluation_per_class(model, loader, score_threshold=0.5, iou_threshold=0.5):
1033
+ """
1034
+ Generate predictions for evaluation on a per-class basis.
1035
+
1036
+ Parameters:
1037
+ - model (torch.nn.Module): The model to evaluate.
1038
+ - loader (torch.utils.data.DataLoader): DataLoader for the dataset.
1039
+ - score_threshold (float): Score threshold for filtering predictions.
1040
+ - iou_threshold (float): IoU threshold for determining matches.
1041
+
1042
+ Yields:
1043
+ - tuple: Predicted and true boxes and labels for each batch.
1044
+ """
1045
  model.eval()
1046
  with torch.no_grad():
1047
  for images, targets_im in tqdm(loader, desc="Testing... "):
 
1071
 
1072
  yield pred_boxes, true_boxes, pred_labels, true_labels
1073
 
1074
+
1075
  def evaluate_model_by_class(model, test_loader, model_dict, score_threshold=0.5, iou_threshold=0.5):
1076
+ """
1077
+ Evaluate the model's performance on a per-class basis for the entire dataset.
1078
+
1079
+ Parameters:
1080
+ - model (torch.nn.Module): The model to evaluate.
1081
+ - test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
1082
+ - model_dict (dict): Dictionary mapping model labels to indices.
1083
+ - score_threshold (float): Score threshold for filtering predictions.
1084
+ - iou_threshold (float): IoU threshold for determining matches.
1085
+
1086
+ Returns:
1087
+ - tuple: Precision, recall, and F1-score per class.
1088
+ """
1089
  class_tp = {cls: 0 for cls in model_dict.values()}
1090
  class_fp = {cls: 0 for cls in model_dict.values()}
1091
  class_fn = {cls: 0 for cls in model_dict.values()}
 
1106
  class_recall[cls] = recall
1107
  class_f1_score[cls] = f1_score
1108
 
1109
+ return class_precision, class_recall, class_f1_score
modules/streamlit_utils.py CHANGED
@@ -15,46 +15,64 @@ from modules.display import draw_stream
15
  from modules.eval import full_prediction
16
  from modules.train import get_faster_rcnn_model, get_arrow_model
17
  from streamlit_image_comparison import image_comparison
18
-
19
  from streamlit_image_annotation import detection
20
  from modules.toXML import create_XML
21
  from modules.eval import develop_prediction, generate_data
22
  from modules.utils import class_dict, object_dict
23
-
24
  from modules.htlm_webpage import display_bpmn_xml
25
  from streamlit_cropper import st_cropper
26
  from streamlit_image_select import image_select
27
  from streamlit_js_eval import streamlit_js_eval
28
-
29
  from modules.toWizard import create_wizard_file
30
  from huggingface_hub import hf_hub_download
31
  import time
32
-
33
  from modules.toXML import get_size_elements
34
 
35
-
36
  def get_memory_usage():
 
 
 
37
  process = psutil.Process()
38
  mem_info = process.memory_info()
39
  return mem_info.rss / (1024 ** 2) # Return memory usage in MB
40
 
 
41
  def clear_memory():
 
 
 
42
  st.session_state.clear()
43
  gc.collect()
44
 
45
-
46
  # Function to read XML content from a file
47
  def read_xml_file(filepath):
48
- """ Read XML content from a file """
 
 
 
 
 
 
 
 
49
  with open(filepath, 'r', encoding='utf-8') as file:
50
  return file.read()
51
 
52
-
53
  # Suppress the symlink warning
54
  os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
55
 
56
  # Function to load the models only once and use session state to keep track of it
57
  def load_models():
 
 
 
 
 
 
 
 
 
58
  with st.spinner('Loading model...'):
59
  model_object = get_faster_rcnn_model(len(object_dict))
60
  model_arrow = get_arrow_model(len(arrow_dict), 2)
@@ -71,7 +89,6 @@ def load_models():
71
 
72
  # Load model arrow
73
  if not Path(output_arrow).exists():
74
- # Download model from Hugging Face Hub
75
  model_arrow.load_state_dict(torch.load(model_arrow_path, map_location=device))
76
  st.session_state.model_arrow = model_arrow
77
  print('Model arrow downloaded from Hugging Face Hub')
@@ -82,22 +99,18 @@ def load_models():
82
  print()
83
  st.session_state.model_arrow = model_arrow
84
  print('Model arrow loaded from local file')
85
-
86
 
87
  # Load model object
88
  if not Path(output_object).exists():
89
- # Download model from Hugging Face Hub
90
  model_object.load_state_dict(torch.load(model_object_path, map_location=device))
91
  st.session_state.model_object = model_object
92
  print('Model object downloaded from Hugging Face Hub')
93
- # Save the model locally
94
  torch.save(model_object.state_dict(), output_object)
95
  elif 'model_object' not in st.session_state and Path(output_object).exists():
96
  model_object.load_state_dict(torch.load(output_object, map_location=device))
97
  print()
98
  st.session_state.model_object = model_object
99
- print('Model object loaded from local file\n')
100
-
101
 
102
  # Move models to device
103
  model_arrow.to(device)
@@ -110,6 +123,17 @@ def load_models():
110
 
111
  # Function to prepare the image for processing
112
  def prepare_image(image, pad=True, new_size=(1333, 1333)):
 
 
 
 
 
 
 
 
 
 
 
113
  original_size = image.size
114
  # Calculate scale to fit the new size while maintaining aspect ratio
115
  scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
@@ -128,6 +152,15 @@ def prepare_image(image, pad=True, new_size=(1333, 1333)):
128
 
129
  # Function to display various options for image annotation
130
  def display_options(image, score_threshold, is_mobile, screen_width):
 
 
 
 
 
 
 
 
 
131
  col1, col2, col3, col4, col5 = st.columns(5)
132
  with col1:
133
  write_class = st.toggle("Write Class", value=True)
@@ -157,7 +190,7 @@ def display_options(image, score_threshold, is_mobile, screen_width):
157
  if is_mobile is True:
158
  width = screen_width
159
  else:
160
- width = screen_width//2
161
 
162
  # Display the original and annotated images side by side
163
  image_comparison(
@@ -171,8 +204,25 @@ def display_options(image, score_threshold, is_mobile, screen_width):
171
 
172
  # Function to perform inference on the uploaded image using the loaded models
173
  def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile, screen_width, iou_threshold=0.5, distance_treshold=30, percentage_text_dist_thresh=0.5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  uploaded_image = prepare_image(image, pad=False)
175
-
176
  img_tensor = F.to_tensor(prepare_image(image.convert('RGB')))
177
 
178
  # Display original image
@@ -181,7 +231,7 @@ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobi
181
  if is_mobile is False:
182
  width = screen_width
183
  if is_mobile is False:
184
- width = screen_width//2
185
  image_placeholder.image(uploaded_image, caption='Original Image', width=width)
186
 
187
  # Perform OCR on the uploaded image
@@ -193,9 +243,9 @@ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobi
193
  # Prediction
194
  _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
195
 
196
- #Mapping text to prediction
197
  st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
198
-
199
  # Remove the original image display
200
  image_placeholder.empty()
201
 
@@ -204,24 +254,44 @@ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobi
204
 
205
  return image, st.session_state.prediction, st.session_state.text_mapping
206
 
 
207
  @st.cache_data
208
  def get_image(uploaded_file):
 
 
 
 
 
 
 
 
 
209
  return Image.open(uploaded_file).convert('RGB')
210
 
211
-
212
  def configure_page():
 
 
 
 
 
 
 
 
213
  st.set_page_config(layout="wide")
214
  screen_width = streamlit_js_eval(js_expressions='screen.width', want_output=True, key='SCR')
215
  is_mobile = screen_width is not None and screen_width < 800
216
  return is_mobile, screen_width
217
 
 
218
  def display_banner(is_mobile):
219
- # JavaScript expression to detect dark mode
220
- dark_mode_js = """
221
- (window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches)
222
  """
 
223
 
224
- # Evaluate JavaScript in Streamlit to check for dark mode
 
 
 
225
  is_dark_mode = streamlit_js_eval(js_expressions=dark_mode_js, key='dark_mode')
226
 
227
  if is_mobile:
@@ -235,16 +305,27 @@ def display_banner(is_mobile):
235
  else:
236
  st.image("./images/banner_desktop.png", use_column_width=True)
237
 
 
238
  def display_title(is_mobile):
 
 
 
 
 
 
239
  title = "Welcome on the BPMN AI model recognition app"
240
  if is_mobile:
241
  title = "Welcome on the mobile version of BPMN AI model recognition app"
242
  st.title(title)
243
 
 
244
  def display_sidebar():
 
 
 
245
  st.sidebar.header("This BPMN AI model recognition is proposed by: \n ELCA in collaboration with EPFL.")
246
  st.sidebar.subheader("Instructions:")
247
- st.sidebar.text("1. Upload you image")
248
  st.sidebar.text("2. Crop the image \n (try to put the BPMN diagram \n in the center of the image)")
249
  st.sidebar.text("3. Set the score threshold for\n prediction (default is 0.5)")
250
  st.sidebar.text("4. Click on 'Launch Prediction'")
@@ -252,20 +333,20 @@ def display_sidebar():
252
  st.sidebar.text("6. You can modify the result \n by clicking on:\n 'Method&Style modification'")
253
  st.sidebar.text("7. You can change the scale for \n the XML file and the size of \n elements (default is 1.0)")
254
  st.sidebar.text("8. You can modify with modeler \n and download the result in \n right format")
255
-
256
  st.sidebar.subheader("If there is an error, try to:")
257
  st.sidebar.text("1. Change the score threshold")
258
  st.sidebar.text("2. Re-crop the image by placing\n the BPMN diagram in the\n center of the image")
259
  st.sidebar.text("3. Re-Launch the prediction")
260
-
261
  st.sidebar.subheader("You can close this sidebar")
262
-
263
  for i in range(5):
264
  st.sidebar.subheader("")
265
-
266
  st.sidebar.subheader("Made with ❤️ by Benjamin.K")
267
 
 
268
  def initialize_session_state():
 
 
 
269
  if 'pool_bboxes' not in st.session_state:
270
  st.session_state.pool_bboxes = []
271
  if 'model_loaded' not in st.session_state:
@@ -275,7 +356,14 @@ def initialize_session_state():
275
  load_models()
276
  st.rerun()
277
 
 
278
  def load_example_image():
 
 
 
 
 
 
279
  with st.expander("Use example images"):
280
  img_selected = image_select(
281
  "If you have no image and just want to test the demo, click on one of these images",
@@ -287,10 +375,20 @@ def load_example_image():
287
  )
288
  return img_selected
289
 
 
290
  def load_user_image(img_selected, is_mobile):
 
 
 
 
 
 
 
 
 
 
291
  if img_selected == './images/none.jpg':
292
  img_selected = None
293
-
294
  if img_selected is not None:
295
  uploaded_file = img_selected
296
  else:
@@ -300,13 +398,23 @@ def load_user_image(img_selected, is_mobile):
300
  col1, col2 = st.columns(2)
301
  with col1:
302
  uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"])
303
-
304
  return uploaded_file
305
 
 
306
  def display_image(uploaded_file, screen_width, is_mobile):
 
 
 
 
 
 
 
 
 
 
 
307
  if 'rotation_angle' not in st.session_state:
308
  st.session_state.rotation_angle = 0 # Initialize the rotation angle in session state
309
-
310
  if 'brightness' not in st.session_state:
311
  st.session_state.brightness = 1.0 # Initialize brightness in session state
312
 
@@ -349,15 +457,23 @@ def display_image(uploaded_file, screen_width, is_mobile):
349
  if not is_mobile:
350
  cropped_image = crop_image(adjusted_image, original_image)
351
  else:
352
- st.image(adjusted_image, caption="Image", use_column_width=False, width=int(4/5 * screen_width))
353
  cropped_image = original_image
354
 
355
  return cropped_image
356
 
357
-
358
-
359
-
360
  def crop_image(resized_image, original_image):
 
 
 
 
 
 
 
 
 
 
361
  marge = 10
362
  cropped_box = st_cropper(
363
  resized_image,
@@ -373,23 +489,50 @@ def crop_image(resized_image, original_image):
373
  cropped_image = original_image.crop((x0, y0, x1, y1))
374
  return cropped_image
375
 
 
376
  def get_score_threshold(is_mobile):
 
 
 
 
 
 
377
  col1, col2 = st.columns(2)
378
  with col1:
379
- st.session_state.score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
380
 
381
  def launch_prediction(cropped_image, score_threshold, is_mobile, screen_width):
 
 
 
 
 
 
 
 
 
 
 
 
382
  st.session_state.crop_image = cropped_image
383
  with st.spinner('Processing...'):
384
- image, _ , _ = perform_inference(
385
  st.session_state.model_object, st.session_state.model_arrow, st.session_state.crop_image,
386
  score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5
387
  )
388
  st.balloons()
389
  return image
390
-
391
 
392
  def modify_results(percentage_text_dist_thresh=0.5):
 
 
 
 
 
 
 
 
 
393
  with st.expander("Method & Style modification"):
394
  label_list = list(object_dict.values())
395
  if st.session_state.prediction['labels'][-1] == 6:
@@ -445,7 +588,6 @@ def modify_results(percentage_text_dist_thresh=0.5):
445
 
446
  object_labels = np.array(object_labels)
447
 
448
-
449
  if len(object_bboxes) == len(bboxes):
450
  # Calculate absolute differences
451
  abs_diff = np.abs(object_bboxes - bboxes)
@@ -456,7 +598,7 @@ def modify_results(percentage_text_dist_thresh=0.5):
456
  changes = True
457
  break
458
 
459
- #check if labels are the same
460
  if not np.array_equal(object_labels, new_lab):
461
  changes = True
462
  else:
@@ -477,7 +619,6 @@ def modify_results(percentage_text_dist_thresh=0.5):
477
  new_scores = np.concatenate((object_scores, arrow_score))
478
  new_keypoints = np.concatenate((object_keypoints, arrow_keypoints))
479
 
480
-
481
  boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, new_scores, new_keypoints, class_dict)
482
 
483
  st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict)
@@ -489,21 +630,35 @@ def modify_results(percentage_text_dist_thresh=0.5):
489
 
490
  return True
491
 
492
-
493
-
494
-
495
  def display_bpmn_modeler(is_mobile, screen_width):
 
 
 
 
 
 
 
496
  with st.spinner('Waiting for BPMN modeler...'):
497
  st.session_state.bpmn_xml = create_XML(
498
  st.session_state.prediction.copy(), st.session_state.text_mapping,
499
  st.session_state.size_scale, st.session_state.scale
500
  )
501
- st.session_state.vizi_file = create_wizard_file(st.session_state.prediction.copy(), st.session_state.text_mapping)
502
 
 
 
503
  display_bpmn_xml(st.session_state.bpmn_xml, st.session_state.vizi_file, is_mobile=is_mobile, screen_width=int(4/5 * screen_width))
504
 
505
-
506
  def find_best_scale(pred, size_elements):
 
 
 
 
 
 
 
 
 
 
507
  boxes = pred['boxes']
508
  labels = pred['labels']
509
 
@@ -535,6 +690,12 @@ def find_best_scale(pred, size_elements):
535
  return best_scale
536
 
537
  def modeler_options(is_mobile):
 
 
 
 
 
 
538
  if not is_mobile:
539
  with st.expander("Options for BPMN modeler"):
540
  col1, col2 = st.columns(2)
@@ -545,4 +706,4 @@ def modeler_options(is_mobile):
545
  st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
546
  else:
547
  st.session_state.scale = 1.0
548
- st.session_state.size_scale = 1.0
 
15
  from modules.eval import full_prediction
16
  from modules.train import get_faster_rcnn_model, get_arrow_model
17
  from streamlit_image_comparison import image_comparison
 
18
  from streamlit_image_annotation import detection
19
  from modules.toXML import create_XML
20
  from modules.eval import develop_prediction, generate_data
21
  from modules.utils import class_dict, object_dict
 
22
  from modules.htlm_webpage import display_bpmn_xml
23
  from streamlit_cropper import st_cropper
24
  from streamlit_image_select import image_select
25
  from streamlit_js_eval import streamlit_js_eval
 
26
  from modules.toWizard import create_wizard_file
27
  from huggingface_hub import hf_hub_download
28
  import time
 
29
  from modules.toXML import get_size_elements
30
 
31
+ # Function to get memory usage
32
  def get_memory_usage():
33
+ """
34
+ Returns the current memory usage of the process in MB.
35
+ """
36
  process = psutil.Process()
37
  mem_info = process.memory_info()
38
  return mem_info.rss / (1024 ** 2) # Return memory usage in MB
39
 
40
+ # Function to clear memory
41
  def clear_memory():
42
+ """
43
+ Clears the Streamlit session state and triggers garbage collection.
44
+ """
45
  st.session_state.clear()
46
  gc.collect()
47
 
 
48
  # Function to read XML content from a file
49
  def read_xml_file(filepath):
50
+ """
51
+ Reads and returns the content of an XML file.
52
+
53
+ Parameters:
54
+ - filepath (str): The path to the XML file.
55
+
56
+ Returns:
57
+ - str: The content of the XML file.
58
+ """
59
  with open(filepath, 'r', encoding='utf-8') as file:
60
  return file.read()
61
 
 
62
  # Suppress the symlink warning
63
  os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
64
 
65
  # Function to load the models only once and use session state to keep track of it
66
  def load_models():
67
+ """
68
+ Loads the object and arrow detection models, either from the local file or
69
+ downloads from the Hugging Face Hub if not available locally. The models
70
+ are stored in the Streamlit session state.
71
+
72
+ Returns:
73
+ - model_object (torch.nn.Module): The loaded object detection model.
74
+ - model_arrow (torch.nn.Module): The loaded arrow detection model.
75
+ """
76
  with st.spinner('Loading model...'):
77
  model_object = get_faster_rcnn_model(len(object_dict))
78
  model_arrow = get_arrow_model(len(arrow_dict), 2)
 
89
 
90
  # Load model arrow
91
  if not Path(output_arrow).exists():
 
92
  model_arrow.load_state_dict(torch.load(model_arrow_path, map_location=device))
93
  st.session_state.model_arrow = model_arrow
94
  print('Model arrow downloaded from Hugging Face Hub')
 
99
  print()
100
  st.session_state.model_arrow = model_arrow
101
  print('Model arrow loaded from local file')
 
102
 
103
  # Load model object
104
  if not Path(output_object).exists():
 
105
  model_object.load_state_dict(torch.load(model_object_path, map_location=device))
106
  st.session_state.model_object = model_object
107
  print('Model object downloaded from Hugging Face Hub')
 
108
  torch.save(model_object.state_dict(), output_object)
109
  elif 'model_object' not in st.session_state and Path(output_object).exists():
110
  model_object.load_state_dict(torch.load(output_object, map_location=device))
111
  print()
112
  st.session_state.model_object = model_object
113
+ print('Model object loaded from local file')
 
114
 
115
  # Move models to device
116
  model_arrow.to(device)
 
123
 
124
  # Function to prepare the image for processing
125
  def prepare_image(image, pad=True, new_size=(1333, 1333)):
126
+ """
127
+ Resizes and optionally pads the input image to a new size.
128
+
129
+ Parameters:
130
+ - image (PIL.Image): The image to be processed.
131
+ - pad (bool): Whether to pad the image to the new size.
132
+ - new_size (tuple): The target size for the image.
133
+
134
+ Returns:
135
+ - PIL.Image: The processed image.
136
+ """
137
  original_size = image.size
138
  # Calculate scale to fit the new size while maintaining aspect ratio
139
  scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
 
152
 
153
  # Function to display various options for image annotation
154
  def display_options(image, score_threshold, is_mobile, screen_width):
155
+ """
156
+ Displays various options for image annotation and draws the annotated image.
157
+
158
+ Parameters:
159
+ - image (PIL.Image): The image to be annotated.
160
+ - score_threshold (float): The score threshold for displaying annotations.
161
+ - is_mobile (bool): Flag indicating if the device is mobile.
162
+ - screen_width (int): The width of the screen.
163
+ """
164
  col1, col2, col3, col4, col5 = st.columns(5)
165
  with col1:
166
  write_class = st.toggle("Write Class", value=True)
 
190
  if is_mobile is True:
191
  width = screen_width
192
  else:
193
+ width = screen_width // 2
194
 
195
  # Display the original and annotated images side by side
196
  image_comparison(
 
204
 
205
  # Function to perform inference on the uploaded image using the loaded models
206
  def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile, screen_width, iou_threshold=0.5, distance_treshold=30, percentage_text_dist_thresh=0.5):
207
+ """
208
+ Performs inference on the uploaded image using the loaded models and updates
209
+ the session state with predictions and text mappings.
210
+
211
+ Parameters:
212
+ - model_object (torch.nn.Module): The object detection model.
213
+ - model_arrow (torch.nn.Module): The arrow detection model.
214
+ - image (PIL.Image): The uploaded image.
215
+ - score_threshold (float): The score threshold for displaying annotations.
216
+ - is_mobile (bool): Flag indicating if the device is mobile.
217
+ - screen_width (int): The width of the screen.
218
+ - iou_threshold (float): The IoU threshold for filtering boxes.
219
+ - distance_treshold (int): The distance threshold for matching keypoints.
220
+ - percentage_text_dist_thresh (float): The percentage distance threshold for text mapping.
221
+
222
+ Returns:
223
+ - tuple: The processed image, prediction, and text mapping.
224
+ """
225
  uploaded_image = prepare_image(image, pad=False)
 
226
  img_tensor = F.to_tensor(prepare_image(image.convert('RGB')))
227
 
228
  # Display original image
 
231
  if is_mobile is False:
232
  width = screen_width
233
  if is_mobile is False:
234
+ width = screen_width // 2
235
  image_placeholder.image(uploaded_image, caption='Original Image', width=width)
236
 
237
  # Perform OCR on the uploaded image
 
243
  # Prediction
244
  _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
245
 
246
+ # Mapping text to prediction
247
  st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
248
+
249
  # Remove the original image display
250
  image_placeholder.empty()
251
 
 
254
 
255
  return image, st.session_state.prediction, st.session_state.text_mapping
256
 
257
+ # Function to get the image from the uploaded file
258
  @st.cache_data
259
  def get_image(uploaded_file):
260
+ """
261
+ Opens and converts the uploaded image file to RGB format.
262
+
263
+ Parameters:
264
+ - uploaded_file: The uploaded image file.
265
+
266
+ Returns:
267
+ - PIL.Image: The opened and converted image.
268
+ """
269
  return Image.open(uploaded_file).convert('RGB')
270
 
271
+ # Function to configure the Streamlit page
272
  def configure_page():
273
+ """
274
+ Configures the Streamlit page layout and returns the screen width
275
+ and a flag indicating if the device is mobile.
276
+
277
+ Returns:
278
+ - is_mobile (bool): Flag indicating if the device is mobile.
279
+ - screen_width (int): The width of the screen.
280
+ """
281
  st.set_page_config(layout="wide")
282
  screen_width = streamlit_js_eval(js_expressions='screen.width', want_output=True, key='SCR')
283
  is_mobile = screen_width is not None and screen_width < 800
284
  return is_mobile, screen_width
285
 
286
+ # Function to display the banner based on device type and theme
287
  def display_banner(is_mobile):
 
 
 
288
  """
289
+ Displays the appropriate banner image based on device type and dark mode preference.
290
 
291
+ Parameters:
292
+ - is_mobile (bool): Flag indicating if the device is mobile.
293
+ """
294
+ dark_mode_js = "(window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches)"
295
  is_dark_mode = streamlit_js_eval(js_expressions=dark_mode_js, key='dark_mode')
296
 
297
  if is_mobile:
 
305
  else:
306
  st.image("./images/banner_desktop.png", use_column_width=True)
307
 
308
+ # Function to display the title based on device type
309
  def display_title(is_mobile):
310
+ """
311
+ Displays the title of the app based on device type.
312
+
313
+ Parameters:
314
+ - is_mobile (bool): Flag indicating if the device is mobile.
315
+ """
316
  title = "Welcome on the BPMN AI model recognition app"
317
  if is_mobile:
318
  title = "Welcome on the mobile version of BPMN AI model recognition app"
319
  st.title(title)
320
 
321
+ # Function to display the sidebar with instructions and information
322
  def display_sidebar():
323
+ """
324
+ Displays the sidebar with instructions and information about the app.
325
+ """
326
  st.sidebar.header("This BPMN AI model recognition is proposed by: \n ELCA in collaboration with EPFL.")
327
  st.sidebar.subheader("Instructions:")
328
+ st.sidebar.text("1. Upload your image")
329
  st.sidebar.text("2. Crop the image \n (try to put the BPMN diagram \n in the center of the image)")
330
  st.sidebar.text("3. Set the score threshold for\n prediction (default is 0.5)")
331
  st.sidebar.text("4. Click on 'Launch Prediction'")
 
333
  st.sidebar.text("6. You can modify the result \n by clicking on:\n 'Method&Style modification'")
334
  st.sidebar.text("7. You can change the scale for \n the XML file and the size of \n elements (default is 1.0)")
335
  st.sidebar.text("8. You can modify with modeler \n and download the result in \n right format")
 
336
  st.sidebar.subheader("If there is an error, try to:")
337
  st.sidebar.text("1. Change the score threshold")
338
  st.sidebar.text("2. Re-crop the image by placing\n the BPMN diagram in the\n center of the image")
339
  st.sidebar.text("3. Re-Launch the prediction")
 
340
  st.sidebar.subheader("You can close this sidebar")
 
341
  for i in range(5):
342
  st.sidebar.subheader("")
 
343
  st.sidebar.subheader("Made with ❤️ by Benjamin.K")
344
 
345
+ # Function to initialize session state variables
346
  def initialize_session_state():
347
+ """
348
+ Initializes the session state variables for the app.
349
+ """
350
  if 'pool_bboxes' not in st.session_state:
351
  st.session_state.pool_bboxes = []
352
  if 'model_loaded' not in st.session_state:
 
356
  load_models()
357
  st.rerun()
358
 
359
+ # Function to load example images for testing
360
  def load_example_image():
361
+ """
362
+ Loads example images for testing the app and returns the selected image.
363
+
364
+ Returns:
365
+ - str: The path to the selected example image.
366
+ """
367
  with st.expander("Use example images"):
368
  img_selected = image_select(
369
  "If you have no image and just want to test the demo, click on one of these images",
 
375
  )
376
  return img_selected
377
 
378
+ # Function to load user-uploaded images or selected example images
379
  def load_user_image(img_selected, is_mobile):
380
+ """
381
+ Loads the user-uploaded image or the selected example image.
382
+
383
+ Parameters:
384
+ - img_selected (str): The path to the selected example image.
385
+ - is_mobile (bool): Flag indicating if the device is mobile.
386
+
387
+ Returns:
388
+ - str: The path to the uploaded image file.
389
+ """
390
  if img_selected == './images/none.jpg':
391
  img_selected = None
 
392
  if img_selected is not None:
393
  uploaded_file = img_selected
394
  else:
 
398
  col1, col2 = st.columns(2)
399
  with col1:
400
  uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"])
 
401
  return uploaded_file
402
 
403
+ # Function to display the uploaded or example image
404
  def display_image(uploaded_file, screen_width, is_mobile):
405
+ """
406
+ Displays the uploaded or selected example image with options to rotate and adjust brightness.
407
+
408
+ Parameters:
409
+ - uploaded_file: The uploaded image file.
410
+ - screen_width (int): The width of the screen.
411
+ - is_mobile (bool): Flag indicating if the device is mobile.
412
+
413
+ Returns:
414
+ - PIL.Image: The cropped and adjusted image.
415
+ """
416
  if 'rotation_angle' not in st.session_state:
417
  st.session_state.rotation_angle = 0 # Initialize the rotation angle in session state
 
418
  if 'brightness' not in st.session_state:
419
  st.session_state.brightness = 1.0 # Initialize brightness in session state
420
 
 
457
  if not is_mobile:
458
  cropped_image = crop_image(adjusted_image, original_image)
459
  else:
460
+ st.image(adjusted_image, caption="Image", use_column_width=False, width=int(4 / 5 * screen_width))
461
  cropped_image = original_image
462
 
463
  return cropped_image
464
 
465
+ # Function to crop the image
 
 
466
  def crop_image(resized_image, original_image):
467
+ """
468
+ Crops the resized image based on user input.
469
+
470
+ Parameters:
471
+ - resized_image (PIL.Image): The resized image.
472
+ - original_image (PIL.Image): The original image.
473
+
474
+ Returns:
475
+ - PIL.Image: The cropped image.
476
+ """
477
  marge = 10
478
  cropped_box = st_cropper(
479
  resized_image,
 
489
  cropped_image = original_image.crop((x0, y0, x1, y1))
490
  return cropped_image
491
 
492
+ # Function to get the score threshold for prediction
493
  def get_score_threshold(is_mobile):
494
+ """
495
+ Displays a slider to set the score threshold for prediction.
496
+
497
+ Parameters:
498
+ - is_mobile (bool): Flag indicating if the device is mobile.
499
+ """
500
  col1, col2 = st.columns(2)
501
  with col1:
502
+ st.session_state.score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
503
 
504
  def launch_prediction(cropped_image, score_threshold, is_mobile, screen_width):
505
+ """
506
+ Launches the prediction process on the cropped image and displays balloons upon completion.
507
+
508
+ Parameters:
509
+ - cropped_image (PIL.Image): The cropped image to be processed.
510
+ - score_threshold (float): The score threshold for predictions.
511
+ - is_mobile (bool): Flag indicating if the device is mobile.
512
+ - screen_width (int): The width of the screen.
513
+
514
+ Returns:
515
+ - PIL.Image: The image after performing inference.
516
+ """
517
  st.session_state.crop_image = cropped_image
518
  with st.spinner('Processing...'):
519
+ image, _, _ = perform_inference(
520
  st.session_state.model_object, st.session_state.model_arrow, st.session_state.crop_image,
521
  score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5
522
  )
523
  st.balloons()
524
  return image
 
525
 
526
  def modify_results(percentage_text_dist_thresh=0.5):
527
+ """
528
+ Allows the user to modify the results using method and style modification.
529
+
530
+ Parameters:
531
+ - percentage_text_dist_thresh (float): Threshold for mapping text to predictions based on percentage distance.
532
+
533
+ Returns:
534
+ - bool: True if changes are detected and modifications are made, otherwise False.
535
+ """
536
  with st.expander("Method & Style modification"):
537
  label_list = list(object_dict.values())
538
  if st.session_state.prediction['labels'][-1] == 6:
 
588
 
589
  object_labels = np.array(object_labels)
590
 
 
591
  if len(object_bboxes) == len(bboxes):
592
  # Calculate absolute differences
593
  abs_diff = np.abs(object_bboxes - bboxes)
 
598
  changes = True
599
  break
600
 
601
+ # Check if labels are the same
602
  if not np.array_equal(object_labels, new_lab):
603
  changes = True
604
  else:
 
619
  new_scores = np.concatenate((object_scores, arrow_score))
620
  new_keypoints = np.concatenate((object_keypoints, arrow_keypoints))
621
 
 
622
  boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, new_scores, new_keypoints, class_dict)
623
 
624
  st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict)
 
630
 
631
  return True
632
 
 
 
 
633
  def display_bpmn_modeler(is_mobile, screen_width):
634
+ """
635
+ Displays the BPMN modeler with the current prediction and text mapping.
636
+
637
+ Parameters:
638
+ - is_mobile (bool): Flag indicating if the device is mobile.
639
+ - screen_width (int): The width of the screen.
640
+ """
641
  with st.spinner('Waiting for BPMN modeler...'):
642
  st.session_state.bpmn_xml = create_XML(
643
  st.session_state.prediction.copy(), st.session_state.text_mapping,
644
  st.session_state.size_scale, st.session_state.scale
645
  )
 
646
 
647
+ st.session_state.vizi_file = create_wizard_file(st.session_state.prediction.copy(), st.session_state.text_mapping)
648
+
649
  display_bpmn_xml(st.session_state.bpmn_xml, st.session_state.vizi_file, is_mobile=is_mobile, screen_width=int(4/5 * screen_width))
650
 
 
651
  def find_best_scale(pred, size_elements):
652
+ """
653
+ Finds the best scale for the elements in the prediction.
654
+
655
+ Parameters:
656
+ - pred (dict): The prediction data.
657
+ - size_elements (dict): The size elements dictionary.
658
+
659
+ Returns:
660
+ - float: The best scale for the elements.
661
+ """
662
  boxes = pred['boxes']
663
  labels = pred['labels']
664
 
 
690
  return best_scale
691
 
692
  def modeler_options(is_mobile):
693
+ """
694
+ Displays options for the BPMN modeler.
695
+
696
+ Parameters:
697
+ - is_mobile (bool): Flag indicating if the device is mobile.
698
+ """
699
  if not is_mobile:
700
  with st.expander("Options for BPMN modeler"):
701
  col1, col2 = st.columns(2)
 
706
  st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
707
  else:
708
  st.session_state.scale = 1.0
709
+ st.session_state.size_scale = 1.0
modules/toWizard.py CHANGED
@@ -4,13 +4,31 @@ from xml.dom import minidom
4
  from modules.utils import error
5
  from modules.OCR import analyze_sentiment
6
 
7
-
8
  def rescale(scale, boxes):
 
 
 
 
 
 
 
 
 
 
9
  for i in range(len(boxes)):
10
  boxes[i] = [boxes[i][0] * scale, boxes[i][1] * scale, boxes[i][2] * scale, boxes[i][3] * scale]
11
  return boxes
12
 
13
  def create_BPMN_id(data):
 
 
 
 
 
 
 
 
 
14
  enum_end, enum_start, enum_task, enum_sequence, enum_dataflow, enum_messflow, enum_messageEvent, enum_exclusiveGateway, enum_parallelGateway, enum_pool = 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
15
  BPMN_name = [class_dict[data['labels'][i]] for i in range(len(data['labels']))]
16
  for idx, Bpmn_id in enumerate(BPMN_name):
@@ -49,15 +67,35 @@ def create_BPMN_id(data):
49
  return data
50
 
51
  def check_end(link):
 
 
 
 
 
 
 
 
 
52
  if link[1] is None:
53
  return True
54
  return False
55
 
56
  def connect(data, text_mapping, i):
 
 
 
 
 
 
 
 
 
 
 
57
  next_text = []
58
  target_idx = data['links'][i][1]
59
  # Check if the target index is valid
60
- if target_idx==None or target_idx >= len(data['links']):
61
  error('There may be an error with the Vizi file, care when you download it.')
62
  return None, None, None
63
 
@@ -80,11 +118,30 @@ def connect(data, text_mapping, i):
80
  return current_text, next_text, next_id
81
 
82
  def check_start(val):
 
 
 
 
 
 
 
 
 
83
  if val[0] is None:
84
  return True
85
  return False
86
 
87
  def find_merge(bpmn_id, links):
 
 
 
 
 
 
 
 
 
 
88
  merge = []
89
  for idx, link in enumerate(links):
90
  next_element = link[1]
@@ -104,7 +161,7 @@ def find_merge(bpmn_id, links):
104
  if element is None:
105
  merge_elements[idx] = False
106
  continue
107
- #count how many time the element is in the list
108
  count = merge.count(element)
109
  if count > 1:
110
  merge_elements[idx] = True
@@ -114,6 +171,17 @@ def find_merge(bpmn_id, links):
114
  return merge_elements
115
 
116
  def find_positive_end(bpmn_ids, links, text_mapping):
 
 
 
 
 
 
 
 
 
 
 
117
  emotion_data = []
118
  for idx, bpmn_id in enumerate(bpmn_ids):
119
  if idx >= len(links):
@@ -130,6 +198,15 @@ def find_positive_end(bpmn_ids, links, text_mapping):
130
  return sorted_emotions[0][0] if len(sorted_emotions) > 0 else None
131
 
132
  def find_best_direction(texts_list):
 
 
 
 
 
 
 
 
 
133
  emotion_data = []
134
  for text in texts_list:
135
  highest_emotion, highest_score = analyze_sentiment(text)
@@ -141,18 +218,24 @@ def find_best_direction(texts_list):
141
 
142
  return sorted_emotions[0][0] if len(sorted_emotions) > 0 else None
143
 
144
-
145
-
146
  def create_wizard_file(data, text_mapping):
 
 
 
 
 
 
147
 
 
 
 
148
  not_change = ['pool','sequenceFlow','messageFlow','dataAssociation']
149
 
150
- #add a name into the text_mapping when there is no name
151
  for idx, key in enumerate(text_mapping.keys()):
152
  if text_mapping[key] == '' and key.split('_')[0] not in not_change:
153
  text_mapping[key] = f'unnamed_{key}'
154
 
155
-
156
  root = ET.Element('methodAndStyleWizard')
157
 
158
  modelName = ET.SubElement(root, 'modelName')
@@ -179,7 +262,7 @@ def create_wizard_file(data, text_mapping):
179
  eventType = 'None'
180
  if idx >= len(data['links']):
181
  continue
182
- if check_start(data['links'][idx]) and (element_type=='event' or element_type=='message'):
183
  if text_mapping[Bpmn_id] == '':
184
  text_mapping[Bpmn_id] = 'start'
185
  startEvent = ET.SubElement(root, 'startEvent', attrib={'name': text_mapping[Bpmn_id], 'eventType': eventType, 'isRegular': 'True'})
@@ -191,8 +274,7 @@ def create_wizard_file(data, text_mapping):
191
 
192
  positive_end = find_positive_end(data['BPMN_id'], data['links'], text_mapping)
193
  if positive_end is not None:
194
- print("Best end is: ",text_mapping[positive_end])
195
-
196
 
197
  # Add end states event to the collaboration element
198
  for idx, Bpmn_id in enumerate(data['BPMN_id']):
@@ -208,7 +290,6 @@ def create_wizard_file(data, text_mapping):
208
  else:
209
  ET.SubElement(endEvents, 'endState', attrib={'name': text_mapping[Bpmn_id], 'eventType': 'None', 'isRegular': 'False'})
210
 
211
-
212
  # Add activities to the collaboration element
213
  activities = ET.SubElement(root, 'activities')
214
  for idx, activity_name in enumerate(data['BPMN_id']):
@@ -269,7 +350,7 @@ def create_wizard_file(data, text_mapping):
269
  ET.SubElement(root, 'participants')
270
 
271
  # Pretty print the XML
272
- xml_str = ET.tostring(root, encoding='utf-8', method='xml')
273
- pretty_xml_str = minidom.parseString(xml_str).toprettyxml(indent=" ")
274
 
275
- return pretty_xml_str
 
4
  from modules.utils import error
5
  from modules.OCR import analyze_sentiment
6
 
 
7
  def rescale(scale, boxes):
8
+ """
9
+ Rescale the coordinates of the bounding boxes by a given scale factor.
10
+
11
+ Args:
12
+ scale (float): The scale factor to apply.
13
+ boxes (list): List of bounding boxes to be rescaled.
14
+
15
+ Returns:
16
+ list: Rescaled bounding boxes.
17
+ """
18
  for i in range(len(boxes)):
19
  boxes[i] = [boxes[i][0] * scale, boxes[i][1] * scale, boxes[i][2] * scale, boxes[i][3] * scale]
20
  return boxes
21
 
22
  def create_BPMN_id(data):
23
+ """
24
+ Create unique BPMN IDs for each element in the data based on their types.
25
+
26
+ Args:
27
+ data (dict): Dictionary containing labels and links of elements.
28
+
29
+ Returns:
30
+ dict: Updated data with BPMN IDs assigned.
31
+ """
32
  enum_end, enum_start, enum_task, enum_sequence, enum_dataflow, enum_messflow, enum_messageEvent, enum_exclusiveGateway, enum_parallelGateway, enum_pool = 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
33
  BPMN_name = [class_dict[data['labels'][i]] for i in range(len(data['labels']))]
34
  for idx, Bpmn_id in enumerate(BPMN_name):
 
67
  return data
68
 
69
  def check_end(link):
70
+ """
71
+ Check if a link represents an end event.
72
+
73
+ Args:
74
+ link (tuple): A link containing indices of connected elements.
75
+
76
+ Returns:
77
+ bool: True if the link represents an end event, False otherwise.
78
+ """
79
  if link[1] is None:
80
  return True
81
  return False
82
 
83
  def connect(data, text_mapping, i):
84
+ """
85
+ Connect elements based on their links and generate the corresponding text mapping.
86
+
87
+ Args:
88
+ data (dict): Data containing links and BPMN IDs.
89
+ text_mapping (dict): Mapping of BPMN IDs to their text descriptions.
90
+ i (int): Index of the current element.
91
+
92
+ Returns:
93
+ tuple: Current text, next texts, and next ID.
94
+ """
95
  next_text = []
96
  target_idx = data['links'][i][1]
97
  # Check if the target index is valid
98
+ if target_idx == None or target_idx >= len(data['links']):
99
  error('There may be an error with the Vizi file, care when you download it.')
100
  return None, None, None
101
 
 
118
  return current_text, next_text, next_id
119
 
120
  def check_start(val):
121
+ """
122
+ Check if a link represents a start event.
123
+
124
+ Args:
125
+ val (tuple): A link containing indices of connected elements.
126
+
127
+ Returns:
128
+ bool: True if the link represents a start event, False otherwise.
129
+ """
130
  if val[0] is None:
131
  return True
132
  return False
133
 
134
  def find_merge(bpmn_id, links):
135
+ """
136
+ Identify merge points in the BPMN diagram.
137
+
138
+ Args:
139
+ bpmn_id (list): List of BPMN IDs.
140
+ links (list): List of links between elements.
141
+
142
+ Returns:
143
+ list: List indicating merge points.
144
+ """
145
  merge = []
146
  for idx, link in enumerate(links):
147
  next_element = link[1]
 
161
  if element is None:
162
  merge_elements[idx] = False
163
  continue
164
+ # Count how many times the element is in the list
165
  count = merge.count(element)
166
  if count > 1:
167
  merge_elements[idx] = True
 
171
  return merge_elements
172
 
173
  def find_positive_end(bpmn_ids, links, text_mapping):
174
+ """
175
+ Find the positive end event based on sentiment analysis.
176
+
177
+ Args:
178
+ bpmn_ids (list): List of BPMN IDs.
179
+ links (list): List of links between elements.
180
+ text_mapping (dict): Mapping of BPMN IDs to their text descriptions.
181
+
182
+ Returns:
183
+ str: BPMN ID of the positive end event.
184
+ """
185
  emotion_data = []
186
  for idx, bpmn_id in enumerate(bpmn_ids):
187
  if idx >= len(links):
 
198
  return sorted_emotions[0][0] if len(sorted_emotions) > 0 else None
199
 
200
  def find_best_direction(texts_list):
201
+ """
202
+ Find the best direction based on sentiment analysis.
203
+
204
+ Args:
205
+ texts_list (list): List of texts to analyze.
206
+
207
+ Returns:
208
+ str: Text with the best (positive) sentiment.
209
+ """
210
  emotion_data = []
211
  for text in texts_list:
212
  highest_emotion, highest_score = analyze_sentiment(text)
 
218
 
219
  return sorted_emotions[0][0] if len(sorted_emotions) > 0 else None
220
 
 
 
221
  def create_wizard_file(data, text_mapping):
222
+ """
223
+ Create a wizard file for BPMN modeling based on the provided data and text mappings.
224
+
225
+ Args:
226
+ data (dict): Data containing BPMN elements and their properties.
227
+ text_mapping (dict): Mapping of BPMN IDs to their text descriptions.
228
 
229
+ Returns:
230
+ str: Pretty-printed XML string of the wizard file.
231
+ """
232
  not_change = ['pool','sequenceFlow','messageFlow','dataAssociation']
233
 
234
+ # Add a name into the text_mapping when there is no name
235
  for idx, key in enumerate(text_mapping.keys()):
236
  if text_mapping[key] == '' and key.split('_')[0] not in not_change:
237
  text_mapping[key] = f'unnamed_{key}'
238
 
 
239
  root = ET.Element('methodAndStyleWizard')
240
 
241
  modelName = ET.SubElement(root, 'modelName')
 
262
  eventType = 'None'
263
  if idx >= len(data['links']):
264
  continue
265
+ if check_start(data['links'][idx]) and (element_type == 'event' or element_type == 'message'):
266
  if text_mapping[Bpmn_id] == '':
267
  text_mapping[Bpmn_id] = 'start'
268
  startEvent = ET.SubElement(root, 'startEvent', attrib={'name': text_mapping[Bpmn_id], 'eventType': eventType, 'isRegular': 'True'})
 
274
 
275
  positive_end = find_positive_end(data['BPMN_id'], data['links'], text_mapping)
276
  if positive_end is not None:
277
+ print("Best end is: ", text_mapping[positive_end])
 
278
 
279
  # Add end states event to the collaboration element
280
  for idx, Bpmn_id in enumerate(data['BPMN_id']):
 
290
  else:
291
  ET.SubElement(endEvents, 'endState', attrib={'name': text_mapping[Bpmn_id], 'eventType': 'None', 'isRegular': 'False'})
292
 
 
293
  # Add activities to the collaboration element
294
  activities = ET.SubElement(root, 'activities')
295
  for idx, activity_name in enumerate(data['BPMN_id']):
 
350
  ET.SubElement(root, 'participants')
351
 
352
  # Pretty print the XML
353
+ pwm_str = ET.tostring(root, encoding='utf-8', method='xml')
354
+ pretty_pwm_str = minidom.parseString(pwm_str).toprettyxml(indent=" ")
355
 
356
+ return pretty_pwm_str
modules/toXML.py CHANGED
@@ -7,7 +7,16 @@ from xml.dom import minidom
7
  import numpy as np
8
 
9
  def find_position(pool_index, BPMN_id):
10
- #find the position of the pool_index in the bpmn_id
 
 
 
 
 
 
 
 
 
11
  if pool_index in BPMN_id:
12
  position = BPMN_id.index(pool_index)
13
  else:
@@ -18,6 +27,16 @@ def find_position(pool_index, BPMN_id):
18
 
19
  # Calculate the center of each bounding box and group them by pool
20
  def calculate_centers_and_group_by_pool(pred, class_dict):
 
 
 
 
 
 
 
 
 
 
21
  pool_groups = {}
22
  for pool_index, element_indices in pred['pool_dict'].items():
23
  pool_groups[pool_index] = []
@@ -26,12 +45,23 @@ def calculate_centers_and_group_by_pool(pred, class_dict):
26
  continue
27
  if class_dict[pred['labels'][i]] not in ['dataObject', 'dataStore']:
28
  x1, y1, x2, y2 = pred['boxes'][i]
29
- center = [(x1 + x2) / 2, (y1 + y2) / 2]
30
  pool_groups[pool_index].append((center, i))
31
  return pool_groups
32
 
33
  # Group centers within a specified range
34
  def group_centers(centers, axis, range_=50):
 
 
 
 
 
 
 
 
 
 
 
35
  groups = []
36
  while centers:
37
  center, idx = centers.pop(0)
@@ -45,18 +75,38 @@ def group_centers(centers, axis, range_=50):
45
 
46
  # Align the elements within each pool
47
  def align_elements_within_pool(modified_pred, pool_groups, class_dict, size):
 
 
 
 
 
 
 
 
 
48
  for pool_index, centers in pool_groups.items():
 
49
  y_groups = group_centers(centers.copy(), axis=1)
50
  align_y_coordinates(modified_pred, y_groups, class_dict, size)
51
 
 
52
  centers = recalculate_centers(modified_pred, y_groups)
53
  x_groups = group_centers(centers.copy(), axis=0)
54
  align_x_coordinates(modified_pred, x_groups, class_dict, size)
55
 
56
  # Align the y-coordinates of the centers of grouped bounding boxes
57
  def align_y_coordinates(modified_pred, y_groups, class_dict, size):
 
 
 
 
 
 
 
 
 
58
  for group in y_groups:
59
- avg_y = sum([c[0][1] for c in group]) / len(group)
60
  for (center, idx) in group:
61
  label = class_dict[modified_pred['labels'][idx]]
62
  if label in size:
@@ -70,18 +120,37 @@ def align_y_coordinates(modified_pred, y_groups, class_dict, size):
70
 
71
  # Recalculate centers after alignment
72
  def recalculate_centers(modified_pred, groups):
 
 
 
 
 
 
 
 
 
 
73
  centers = []
74
  for group in groups:
75
  for center, idx in group:
76
  x1, y1, x2, y2 = modified_pred['boxes'][idx]
77
- center = [(x1 + x2) / 2, (y1 + y2) / 2]
78
  centers.append((center, idx))
79
  return centers
80
 
81
  # Align the x-coordinates of the centers of grouped bounding boxes
82
  def align_x_coordinates(modified_pred, x_groups, class_dict, size):
 
 
 
 
 
 
 
 
 
83
  for group in x_groups:
84
- avg_x = sum([c[0][0] for c in group]) / len(group)
85
  for (center, idx) in group:
86
  label = class_dict[modified_pred['labels'][idx]]
87
  if label in size:
@@ -95,6 +164,13 @@ def align_x_coordinates(modified_pred, x_groups, class_dict, size):
95
 
96
  # Expand the pool bounding boxes to fit the aligned elements
97
  def expand_pool_bounding_boxes(modified_pred, size_elements):
 
 
 
 
 
 
 
98
  for idx, (pool_index, keep_elements) in enumerate(modified_pred['pool_dict'].items()):
99
  if len(keep_elements) != 0:
100
  marge = size_elements['task'][1] // 2
@@ -114,10 +190,18 @@ def expand_pool_bounding_boxes(modified_pred, size_elements):
114
  error("The pool is maybe too small, please add more elements or increase the scale by zooming on the image.")
115
  continue
116
 
 
117
  modified_pred['boxes'][position] = [min_x - marge, min_y - marge//2, min_x + pool_width + marge, min_y + pool_height + marge//2]
118
 
119
  # Adjust left and right boundaries of all pools
120
  def adjust_pool_boundaries(modified_pred, pred):
 
 
 
 
 
 
 
121
  min_left, max_right = 0, 0
122
  for pool_index, element_indices in pred['pool_dict'].items():
123
  position = find_position(pool_index, modified_pred['BPMN_id'])
@@ -140,10 +224,22 @@ def adjust_pool_boundaries(modified_pred, pred):
140
  x1 = min_left
141
  if x2 < max_right:
142
  x2 = max_right
 
143
  modified_pred['boxes'][position] = [x1, y1, x2, y2]
144
 
145
  # Main function to align boxes
146
  def align_boxes(pred, size, class_dict):
 
 
 
 
 
 
 
 
 
 
 
147
  modified_pred = copy.deepcopy(pred)
148
  pool_groups = calculate_centers_and_group_by_pool(pred, class_dict)
149
  align_elements_within_pool(modified_pred, pool_groups, class_dict, size)
@@ -154,9 +250,20 @@ def align_boxes(pred, size, class_dict):
154
 
155
  return modified_pred['boxes']
156
 
157
-
158
  # Function to create a BPMN XML file from prediction results
159
  def create_XML(full_pred, text_mapping, size_scale, scale):
 
 
 
 
 
 
 
 
 
 
 
 
160
  namespaces = {
161
  'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
162
  'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
@@ -165,7 +272,6 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
165
  'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
166
  }
167
 
168
-
169
  definitions = ET.Element('bpmn:definitions', {
170
  'xmlns:xsi': namespaces['xsi'],
171
  'xmlns:bpmn': namespaces['bpmn'],
@@ -176,14 +282,13 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
176
  'id': "simpleExample"
177
  })
178
 
179
-
180
  size_elements = get_size_elements(size_scale)
181
 
182
- #if there is no pool or lane, create a pool with all elements
183
  if len(full_pred['pool_dict']) == 0 or (len(full_pred['pool_dict']) == 1 and len(next(iter(full_pred['pool_dict'].values()))) == len(full_pred['labels'])):
184
  full_pred, text_mapping = create_big_pool(full_pred, text_mapping, size_elements)
185
 
186
- #modify the boxes positions
187
  old_boxes = copy.deepcopy(full_pred)
188
 
189
  # Create BPMN collaboration element
@@ -191,16 +296,16 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
191
 
192
  # Create BPMN process elements
193
  process = []
194
- for idx in range (len(full_pred['pool_dict'].items())):
195
- process_id = f'process_{idx+1}'
196
  process.append(ET.SubElement(definitions, 'bpmn:process', id=process_id, isExecutable='false'))
197
 
198
  bpmndi = ET.SubElement(definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
199
  bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')
200
 
 
201
  full_pred['boxes'] = rescale_boxes(scale, old_boxes['boxes'])
202
  full_pred['boxes'] = align_boxes(full_pred, size_elements, class_dict)
203
-
204
 
205
  # Add diagram elements for each pool
206
  for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
@@ -208,8 +313,6 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
208
  pool = ET.SubElement(collaboration, 'bpmn:participant', id=pool_id, processRef=f'process_{idx+1}', name=text_mapping[pool_index])
209
 
210
  position = find_position(pool_index, full_pred['BPMN_id'])
211
- # Calculate the bounding box for the pool
212
- #if len(keep_elements) == 0:
213
  if position >= len(full_pred['boxes']):
214
  print("Problem with the index")
215
  continue
@@ -219,7 +322,6 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
219
 
220
  add_diagram_elements(bpmnplane, pool_id, min_x, min_y, pool_width, pool_height)
221
 
222
-
223
  # Create BPMN elements for each pool
224
  for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
225
  create_bpmn_object(process[idx], bpmnplane, text_mapping, definitions, size_elements, full_pred, keep_elements)
@@ -244,6 +346,7 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
244
  reparsed = minidom.parseString(rough_string)
245
  pretty_xml_as_string = reparsed.toprettyxml(indent=" ")
246
 
 
247
  full_pred['boxes'] = rescale_boxes(1/scale, full_pred['boxes'])
248
  full_pred['boxes'] = old_boxes
249
 
@@ -251,11 +354,22 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
251
 
252
  # Function that creates a single pool with all elements
253
  def create_big_pool(full_pred, text_mapping, size_elements, marge=50):
254
- # If no pools or lanes are detected, create a single pool with all elements
 
 
 
 
 
 
 
 
 
 
 
255
  new_pool_index = 'pool_1'
256
  size_elements = get_size_elements(st.session_state.size_scale)
257
  elements_pool = list(range(len(full_pred['boxes'])))
258
- min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred['boxes'],full_pred['labels'], elements_pool, size_elements)
259
  box = [min_x - marge, min_y - marge//2, max_x + marge, max_y + marge//2]
260
  full_pred['boxes'] = np.append(full_pred['boxes'], [box], axis=0)
261
  full_pred['pool_dict'][new_pool_index] = elements_pool
@@ -266,33 +380,61 @@ def create_big_pool(full_pred, text_mapping, size_elements, marge=50):
266
 
267
  # Function that gives the size of the elements
268
  def get_size_elements(size_scale=1):
 
 
 
 
 
 
 
 
 
269
  size_elements = {
270
- 'event': (size_scale*43.2, size_scale*43.2),
271
- 'task': (size_scale*120, size_scale*96),
272
- 'message': (size_scale*43.2, size_scale*43.2),
273
- 'messageEvent': (size_scale*43.2, size_scale*43.2),
274
- 'exclusiveGateway': (size_scale*60, size_scale*60),
275
- 'parallelGateway': (size_scale*60, size_scale*60),
276
- 'dataObject': (size_scale*48, size_scale*72),
277
- 'dataStore': (size_scale*72, size_scale*72),
278
- 'subProcess': (size_scale*144, size_scale*108),
279
- 'eventBasedGateway': (size_scale*60, size_scale*60),
280
- 'timerEvent': (size_scale*48, size_scale*48),
281
  }
282
  return size_elements
283
 
284
  def rescale(scale, boxes):
 
 
 
 
 
 
 
 
 
 
285
  for i in range(len(boxes)):
286
- boxes[i] = [boxes[i][0]*scale,
287
- boxes[i][1]*scale,
288
- boxes[i][2]*scale,
289
- boxes[i][3]*scale]
290
  return boxes
291
 
292
- #Function to create the unique BPMN_id
293
- def create_BPMN_id(labels,pool_dict):
 
 
 
 
 
 
294
 
295
- BPMN_id = [class_dict[labels[i]] for i in range(len(labels))]
 
 
 
296
 
297
  data_counter = 1
298
 
@@ -336,7 +478,7 @@ def create_BPMN_id(labels,pool_dict):
336
  else:
337
  BPMN_id[idx] = f'{key}_{enums[key]}'
338
  enums[key] += 1
339
-
340
  # Update the pool_dict keys with their corresponding BPMN_id values
341
  updated_pool_dict = {}
342
  for key, value in pool_dict.items():
@@ -346,10 +488,18 @@ def create_BPMN_id(labels,pool_dict):
346
 
347
  return BPMN_id, updated_pool_dict
348
 
349
-
350
-
351
  def add_diagram_elements(parent, element_id, x, y, width, height):
352
- """Utility to add BPMN diagram notation for elements."""
 
 
 
 
 
 
 
 
 
 
353
  shape = ET.SubElement(parent, 'bpmndi:BPMNShape', attrib={
354
  'bpmnElement': element_id,
355
  'id': element_id + '_di'
@@ -362,7 +512,14 @@ def add_diagram_elements(parent, element_id, x, y, width, height):
362
  })
363
 
364
  def add_diagram_edge(parent, element_id, waypoints):
365
- """Utility to add BPMN diagram notation for sequence flows."""
 
 
 
 
 
 
 
366
  edge = ET.SubElement(parent, 'bpmndi:BPMNEdge', attrib={
367
  'bpmnElement': element_id,
368
  'id': element_id + '_di'
@@ -375,8 +532,17 @@ def add_diagram_edge(parent, element_id, waypoints):
375
  'y': str(y)
376
  })
377
 
378
-
379
  def check_status(link, keep_elements):
 
 
 
 
 
 
 
 
 
 
380
  if link[0] in keep_elements and link[1] in keep_elements:
381
  return 'middle'
382
  elif link[0] is None and link[1] in keep_elements:
@@ -385,40 +551,87 @@ def check_status(link, keep_elements):
385
  return 'end'
386
  else:
387
  return 'middle'
388
-
389
  def check_data_association(i, links, labels, keep_elements):
 
 
 
 
 
 
 
 
 
 
 
 
390
  status, links_idx = [], []
391
- for j, (k,l) in enumerate(links):
392
  if labels[j] == list(class_dict.values()).index('dataAssociation'):
393
- if k==i:
394
  status.append('output')
395
  links_idx.append(j)
396
- elif l==i:
397
  status.append('input')
398
  links_idx.append(j)
399
 
400
  return status, links_idx
401
 
402
- def create_data_Association(bpmn,data,size,element_id,current_idx,source_id,target_id):
 
 
 
 
 
 
 
 
 
 
 
 
403
  waypoints = calculate_waypoints(data, size, current_idx, source_id, target_id)
404
  if waypoints is not None:
405
  add_diagram_edge(bpmn, element_id, waypoints)
406
-
407
  def check_eventBasedGateway(i, links, labels):
 
 
 
 
 
 
 
 
 
 
 
408
  status, links_idx = [], []
409
- for j, (k,l) in enumerate(links):
410
  if labels[j] == list(class_dict.values()).index('sequenceFlow'):
411
- if k==i:
412
  status.append('output')
413
  links_idx.append(j)
414
- elif l==i:
415
  status.append('input')
416
  links_idx.append(j)
417
 
418
  return status, links_idx
419
-
420
  # Function to dynamically create and layout BPMN elements
421
  def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data, keep_elements):
 
 
 
 
 
 
 
 
 
 
 
 
422
  elements = data['BPMN_id']
423
  positions = data['boxes']
424
  links = data['links']
@@ -536,7 +749,6 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
536
  sub_element = ET.SubElement(element, 'bpmn:eventBasedGateway', id=f'eventBasedGateway_{link_idx}_{gateway_name.split("_")[1]}')
537
  create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], i, element_id, gateway_name)
538
 
539
-
540
  add_diagram_elements(bpmnplane, element_id, x, y, size['eventBasedGateway'][0], size['eventBasedGateway'][1])
541
 
542
  # Data Object
@@ -558,6 +770,19 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
558
  add_diagram_elements(bpmnplane, element_id, x, y, size['timerEvent'][0], size['timerEvent'][1])
559
 
560
  def calculate_pool_bounds(boxes, labels, keep_elements, size=None, class_dict=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
561
  min_x, min_y = float('inf'), float('inf')
562
  max_x, max_y = float('-inf'), float('-inf')
563
 
@@ -588,9 +813,22 @@ def calculate_pool_bounds(boxes, labels, keep_elements, size=None, class_dict=No
588
 
589
  return min_x, min_y, max_x, max_y
590
 
591
-
592
-
593
  def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_element, target_element):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594
  # Get the bounding boxes of the source and target elements
595
  source_box = data['boxes'][source_idx]
596
  target_box = data['boxes'][target_idx]
@@ -625,11 +863,19 @@ def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_ele
625
  waypoints = [(element_mid_x, element_box[3]), (element_mid_x, pool_box[1])]
626
 
627
  return waypoints
628
-
629
  def add_curve(waypoints, pos_source, pos_target, threshold=30):
630
  """
631
  Add a single curve to the sequence flow by introducing a control point.
632
  The control point is added at an offset from the midpoint of the original waypoints.
 
 
 
 
 
 
 
 
 
633
  """
634
  if len(waypoints) < 2:
635
  return waypoints
@@ -647,7 +893,7 @@ def add_curve(waypoints, pos_source, pos_target, threshold=30):
647
  if abs(start_x - end_x) < threshold or abs(start_y - end_y) < threshold:
648
  return waypoints
649
 
650
- # Calculate the control point
651
  if pos_source in pos_horizontal and pos_target in pos_horizontal:
652
  control_point = None
653
  elif pos_source in pos_vertical and pos_target in pos_vertical:
@@ -658,7 +904,6 @@ def add_curve(waypoints, pos_source, pos_target, threshold=30):
658
  control_point = (start_x, end_y)
659
  else:
660
  control_point = None
661
-
662
 
663
  # Create the curved path
664
  if control_point is not None:
@@ -668,8 +913,20 @@ def add_curve(waypoints, pos_source, pos_target, threshold=30):
668
 
669
  return curved_waypoints
670
 
671
-
672
  def calculate_waypoints(data, size, current_idx, source_id, target_id):
 
 
 
 
 
 
 
 
 
 
 
 
 
673
  best_points = data['best_points'][current_idx]
674
  pos_source = best_points[0]
675
  pos_target = best_points[1]
@@ -684,7 +941,6 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
684
  if source_idx is None or target_idx is None:
685
  warning()
686
  return None
687
-
688
 
689
  name_source = source_id.split('_')[0]
690
  name_target = target_id.split('_')[0]
@@ -702,6 +958,7 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
702
  warning()
703
  return [(source_x, source_y), (target_x, target_y)]
704
 
 
705
  if pos_source == 'left':
706
  source_x = source_x
707
  source_y += size[name_source][1] / 2
@@ -715,6 +972,7 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
715
  source_x += size[name_source][0] / 2
716
  source_y += size[name_source][1]
717
 
 
718
  if pos_target == 'left':
719
  target_x = target_x
720
  target_y += size[name_target][1] / 2
@@ -738,8 +996,19 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
738
 
739
  return curved_waypoints
740
 
741
-
742
  def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=False):
 
 
 
 
 
 
 
 
 
 
 
 
743
  source_idx, target_idx = data['links'][idx]
744
 
745
  if source_idx is None or target_idx is None:
@@ -774,6 +1043,3 @@ def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=Fal
774
  return
775
  element = ET.SubElement(parent, 'bpmn:sequenceFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
776
  add_diagram_edge(bpmn, element_id, waypoints)
777
-
778
-
779
-
 
7
  import numpy as np
8
 
9
  def find_position(pool_index, BPMN_id):
10
+ """
11
+ Find the position of the pool index in the BPMN_id list.
12
+
13
+ Args:
14
+ pool_index (str): The pool index to search for.
15
+ BPMN_id (list): List of BPMN IDs.
16
+
17
+ Returns:
18
+ int: The index of the pool_index in BPMN_id, or None if not found.
19
+ """
20
  if pool_index in BPMN_id:
21
  position = BPMN_id.index(pool_index)
22
  else:
 
27
 
28
  # Calculate the center of each bounding box and group them by pool
29
  def calculate_centers_and_group_by_pool(pred, class_dict):
30
+ """
31
+ Calculate the center coordinates of bounding boxes and group them by pool.
32
+
33
+ Args:
34
+ pred (dict): Dictionary containing prediction results, including 'pool_dict', 'boxes', and 'labels'.
35
+ class_dict (dict): Dictionary mapping class indices to class names.
36
+
37
+ Returns:
38
+ dict: Dictionary grouping centers and their indices by pool index.
39
+ """
40
  pool_groups = {}
41
  for pool_index, element_indices in pred['pool_dict'].items():
42
  pool_groups[pool_index] = []
 
45
  continue
46
  if class_dict[pred['labels'][i]] not in ['dataObject', 'dataStore']:
47
  x1, y1, x2, y2 = pred['boxes'][i]
48
+ center = [(x1 + x2) / 2, (y1 + y2) / 2] # Compute the center of the bounding box
49
  pool_groups[pool_index].append((center, i))
50
  return pool_groups
51
 
52
  # Group centers within a specified range
53
  def group_centers(centers, axis, range_=50):
54
+ """
55
+ Group centers based on a specified range along an axis.
56
+
57
+ Args:
58
+ centers (list): List of center coordinates and their indices.
59
+ axis (int): The axis (0 for x, 1 for y) to group centers along.
60
+ range_ (int): Maximum distance to consider centers as part of the same group.
61
+
62
+ Returns:
63
+ list: List of groups, where each group is a list of centers and indices.
64
+ """
65
  groups = []
66
  while centers:
67
  center, idx = centers.pop(0)
 
75
 
76
  # Align the elements within each pool
77
  def align_elements_within_pool(modified_pred, pool_groups, class_dict, size):
78
+ """
79
+ Align elements within each pool based on their centers.
80
+
81
+ Args:
82
+ modified_pred (dict): Dictionary containing the modified predictions.
83
+ pool_groups (dict): Dictionary grouping centers and their indices by pool index.
84
+ class_dict (dict): Dictionary mapping class indices to class names.
85
+ size (dict): Dictionary containing element sizes.
86
+ """
87
  for pool_index, centers in pool_groups.items():
88
+ # Align elements based on y-coordinates
89
  y_groups = group_centers(centers.copy(), axis=1)
90
  align_y_coordinates(modified_pred, y_groups, class_dict, size)
91
 
92
+ # Recalculate centers after y-alignment and then align based on x-coordinates
93
  centers = recalculate_centers(modified_pred, y_groups)
94
  x_groups = group_centers(centers.copy(), axis=0)
95
  align_x_coordinates(modified_pred, x_groups, class_dict, size)
96
 
97
  # Align the y-coordinates of the centers of grouped bounding boxes
98
  def align_y_coordinates(modified_pred, y_groups, class_dict, size):
99
+ """
100
+ Align the y-coordinates of elements in each group.
101
+
102
+ Args:
103
+ modified_pred (dict): Dictionary containing the modified predictions.
104
+ y_groups (list): List of groups of centers and their indices, grouped by y-coordinate.
105
+ class_dict (dict): Dictionary mapping class indices to class names.
106
+ size (dict): Dictionary containing element sizes.
107
+ """
108
  for group in y_groups:
109
+ avg_y = sum([c[0][1] for c in group]) / len(group) # Compute the average y-coordinate
110
  for (center, idx) in group:
111
  label = class_dict[modified_pred['labels'][idx]]
112
  if label in size:
 
120
 
121
  # Recalculate centers after alignment
122
  def recalculate_centers(modified_pred, groups):
123
+ """
124
+ Recalculate the centers of bounding boxes after alignment.
125
+
126
+ Args:
127
+ modified_pred (dict): Dictionary containing the modified predictions.
128
+ groups (list): List of groups of centers and their indices.
129
+
130
+ Returns:
131
+ list: List of recalculated centers and their indices.
132
+ """
133
  centers = []
134
  for group in groups:
135
  for center, idx in group:
136
  x1, y1, x2, y2 = modified_pred['boxes'][idx]
137
+ center = [(x1 + x2) / 2, (y1 + y2) / 2] # Recompute the center after alignment
138
  centers.append((center, idx))
139
  return centers
140
 
141
  # Align the x-coordinates of the centers of grouped bounding boxes
142
  def align_x_coordinates(modified_pred, x_groups, class_dict, size):
143
+ """
144
+ Align the x-coordinates of elements in each group.
145
+
146
+ Args:
147
+ modified_pred (dict): Dictionary containing the modified predictions.
148
+ x_groups (list): List of groups of centers and their indices, grouped by x-coordinate.
149
+ class_dict (dict): Dictionary mapping class indices to class names.
150
+ size (dict): Dictionary containing element sizes.
151
+ """
152
  for group in x_groups:
153
+ avg_x = sum([c[0][0] for c in group]) / len(group) # Compute the average x-coordinate
154
  for (center, idx) in group:
155
  label = class_dict[modified_pred['labels'][idx]]
156
  if label in size:
 
164
 
165
  # Expand the pool bounding boxes to fit the aligned elements
166
  def expand_pool_bounding_boxes(modified_pred, size_elements):
167
+ """
168
+ Expand the bounding boxes of pools to fit aligned elements.
169
+
170
+ Args:
171
+ modified_pred (dict): Dictionary containing the modified predictions.
172
+ size_elements (dict): Dictionary containing element sizes.
173
+ """
174
  for idx, (pool_index, keep_elements) in enumerate(modified_pred['pool_dict'].items()):
175
  if len(keep_elements) != 0:
176
  marge = size_elements['task'][1] // 2
 
190
  error("The pool is maybe too small, please add more elements or increase the scale by zooming on the image.")
191
  continue
192
 
193
+ # Update the pool bounding box with margin
194
  modified_pred['boxes'][position] = [min_x - marge, min_y - marge//2, min_x + pool_width + marge, min_y + pool_height + marge//2]
195
 
196
  # Adjust left and right boundaries of all pools
197
  def adjust_pool_boundaries(modified_pred, pred):
198
+ """
199
+ Adjust the left and right boundaries of all pools to ensure they cover all elements.
200
+
201
+ Args:
202
+ modified_pred (dict): Dictionary containing the modified predictions.
203
+ pred (dict): Dictionary containing original prediction results.
204
+ """
205
  min_left, max_right = 0, 0
206
  for pool_index, element_indices in pred['pool_dict'].items():
207
  position = find_position(pool_index, modified_pred['BPMN_id'])
 
224
  x1 = min_left
225
  if x2 < max_right:
226
  x2 = max_right
227
+ # Update the pool bounding box with adjusted boundaries
228
  modified_pred['boxes'][position] = [x1, y1, x2, y2]
229
 
230
  # Main function to align boxes
231
  def align_boxes(pred, size, class_dict):
232
+ """
233
+ Main function to align bounding boxes for the given prediction data.
234
+
235
+ Args:
236
+ pred (dict): Dictionary containing prediction results.
237
+ size (dict): Dictionary containing element sizes.
238
+ class_dict (dict): Dictionary mapping class indices to class names.
239
+
240
+ Returns:
241
+ list: List of aligned bounding boxes.
242
+ """
243
  modified_pred = copy.deepcopy(pred)
244
  pool_groups = calculate_centers_and_group_by_pool(pred, class_dict)
245
  align_elements_within_pool(modified_pred, pool_groups, class_dict, size)
 
250
 
251
  return modified_pred['boxes']
252
 
 
253
  # Function to create a BPMN XML file from prediction results
254
  def create_XML(full_pred, text_mapping, size_scale, scale):
255
+ """
256
+ Create a BPMN XML file from the prediction results.
257
+
258
+ Args:
259
+ full_pred (dict): Dictionary containing full prediction results.
260
+ text_mapping (dict): Dictionary mapping BPMN IDs to text labels.
261
+ size_scale (float): Scaling factor for element sizes.
262
+ scale (float): Scaling factor for bounding boxes.
263
+
264
+ Returns:
265
+ str: Pretty-printed BPMN XML string.
266
+ """
267
  namespaces = {
268
  'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
269
  'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
 
272
  'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
273
  }
274
 
 
275
  definitions = ET.Element('bpmn:definitions', {
276
  'xmlns:xsi': namespaces['xsi'],
277
  'xmlns:bpmn': namespaces['bpmn'],
 
282
  'id': "simpleExample"
283
  })
284
 
 
285
  size_elements = get_size_elements(size_scale)
286
 
287
+ # If there is no pool or lane, create a pool with all elements
288
  if len(full_pred['pool_dict']) == 0 or (len(full_pred['pool_dict']) == 1 and len(next(iter(full_pred['pool_dict'].values()))) == len(full_pred['labels'])):
289
  full_pred, text_mapping = create_big_pool(full_pred, text_mapping, size_elements)
290
 
291
+ # Backup the original box positions
292
  old_boxes = copy.deepcopy(full_pred)
293
 
294
  # Create BPMN collaboration element
 
296
 
297
  # Create BPMN process elements
298
  process = []
299
+ for idx in range(len(full_pred['pool_dict'].items())):
300
+ process_id = f'process_{idx+1}'
301
  process.append(ET.SubElement(definitions, 'bpmn:process', id=process_id, isExecutable='false'))
302
 
303
  bpmndi = ET.SubElement(definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
304
  bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')
305
 
306
+ # Rescale and align bounding boxes
307
  full_pred['boxes'] = rescale_boxes(scale, old_boxes['boxes'])
308
  full_pred['boxes'] = align_boxes(full_pred, size_elements, class_dict)
 
309
 
310
  # Add diagram elements for each pool
311
  for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
 
313
  pool = ET.SubElement(collaboration, 'bpmn:participant', id=pool_id, processRef=f'process_{idx+1}', name=text_mapping[pool_index])
314
 
315
  position = find_position(pool_index, full_pred['BPMN_id'])
 
 
316
  if position >= len(full_pred['boxes']):
317
  print("Problem with the index")
318
  continue
 
322
 
323
  add_diagram_elements(bpmnplane, pool_id, min_x, min_y, pool_width, pool_height)
324
 
 
325
  # Create BPMN elements for each pool
326
  for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
327
  create_bpmn_object(process[idx], bpmnplane, text_mapping, definitions, size_elements, full_pred, keep_elements)
 
346
  reparsed = minidom.parseString(rough_string)
347
  pretty_xml_as_string = reparsed.toprettyxml(indent=" ")
348
 
349
+ # Restore the original box positions
350
  full_pred['boxes'] = rescale_boxes(1/scale, full_pred['boxes'])
351
  full_pred['boxes'] = old_boxes
352
 
 
354
 
355
  # Function that creates a single pool with all elements
356
  def create_big_pool(full_pred, text_mapping, size_elements, marge=50):
357
+ """
358
+ Create a single pool containing all elements if no pools or lanes are detected.
359
+
360
+ Args:
361
+ full_pred (dict): Dictionary containing full prediction results.
362
+ text_mapping (dict): Dictionary mapping BPMN IDs to text labels.
363
+ size_elements (dict): Dictionary containing element sizes.
364
+ marge (int, optional): Margin to add around the pool. Defaults to 50.
365
+
366
+ Returns:
367
+ tuple: Updated full_pred and text_mapping.
368
+ """
369
  new_pool_index = 'pool_1'
370
  size_elements = get_size_elements(st.session_state.size_scale)
371
  elements_pool = list(range(len(full_pred['boxes'])))
372
+ min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred['boxes'], full_pred['labels'], elements_pool, size_elements)
373
  box = [min_x - marge, min_y - marge//2, max_x + marge, max_y + marge//2]
374
  full_pred['boxes'] = np.append(full_pred['boxes'], [box], axis=0)
375
  full_pred['pool_dict'][new_pool_index] = elements_pool
 
380
 
381
  # Function that gives the size of the elements
382
  def get_size_elements(size_scale=1):
383
+ """
384
+ Get the sizes of BPMN elements based on the scaling factor.
385
+
386
+ Args:
387
+ size_scale (float, optional): Scaling factor for element sizes. Defaults to 1.
388
+
389
+ Returns:
390
+ dict: Dictionary containing element sizes.
391
+ """
392
  size_elements = {
393
+ 'event': (size_scale * 43.2, size_scale * 43.2),
394
+ 'task': (size_scale * 120, size_scale * 96),
395
+ 'message': (size_scale * 43.2, size_scale * 43.2),
396
+ 'messageEvent': (size_scale * 43.2, size_scale * 43.2),
397
+ 'exclusiveGateway': (size_scale * 60, size_scale * 60),
398
+ 'parallelGateway': (size_scale * 60, size_scale * 60),
399
+ 'dataObject': (size_scale * 48, size_scale * 72),
400
+ 'dataStore': (size_scale * 72, size_scale * 72),
401
+ 'subProcess': (size_scale * 144, size_scale * 108),
402
+ 'eventBasedGateway': (size_scale * 60, size_scale * 60),
403
+ 'timerEvent': (size_scale * 48, size_scale * 48),
404
  }
405
  return size_elements
406
 
407
  def rescale(scale, boxes):
408
+ """
409
+ Rescale the bounding boxes by a given scaling factor.
410
+
411
+ Args:
412
+ scale (float): Scaling factor.
413
+ boxes (list): List of bounding boxes.
414
+
415
+ Returns:
416
+ list: Rescaled bounding boxes.
417
+ """
418
  for i in range(len(boxes)):
419
+ boxes[i] = [boxes[i][0] * scale,
420
+ boxes[i][1] * scale,
421
+ boxes[i][2] * scale,
422
+ boxes[i][3] * scale]
423
  return boxes
424
 
425
+ # Function to create the unique BPMN_id
426
+ def create_BPMN_id(labels, pool_dict):
427
+ """
428
+ Create unique BPMN IDs for each element based on their labels.
429
+
430
+ Args:
431
+ labels (list): List of labels for each element.
432
+ pool_dict (dict): Dictionary containing pool indices and their elements.
433
 
434
+ Returns:
435
+ tuple: List of BPMN IDs and updated pool dictionary.
436
+ """
437
+ BPMN_id = [class_dict[labels[i]] for i in range(len(labels))]
438
 
439
  data_counter = 1
440
 
 
478
  else:
479
  BPMN_id[idx] = f'{key}_{enums[key]}'
480
  enums[key] += 1
481
+
482
  # Update the pool_dict keys with their corresponding BPMN_id values
483
  updated_pool_dict = {}
484
  for key, value in pool_dict.items():
 
488
 
489
  return BPMN_id, updated_pool_dict
490
 
 
 
491
  def add_diagram_elements(parent, element_id, x, y, width, height):
492
+ """
493
+ Utility to add BPMN diagram notation for elements.
494
+
495
+ Args:
496
+ parent (Element): The parent XML element.
497
+ element_id (str): The ID of the BPMN element.
498
+ x (float): The x-coordinate of the element.
499
+ y (float): The y-coordinate of the element.
500
+ width (float): The width of the element.
501
+ height (float): The height of the element.
502
+ """
503
  shape = ET.SubElement(parent, 'bpmndi:BPMNShape', attrib={
504
  'bpmnElement': element_id,
505
  'id': element_id + '_di'
 
512
  })
513
 
514
  def add_diagram_edge(parent, element_id, waypoints):
515
+ """
516
+ Utility to add BPMN diagram notation for sequence flows.
517
+
518
+ Args:
519
+ parent (Element): The parent XML element.
520
+ element_id (str): The ID of the BPMN element.
521
+ waypoints (list): List of waypoints for the sequence flow.
522
+ """
523
  edge = ET.SubElement(parent, 'bpmndi:BPMNEdge', attrib={
524
  'bpmnElement': element_id,
525
  'id': element_id + '_di'
 
532
  'y': str(y)
533
  })
534
 
 
535
  def check_status(link, keep_elements):
536
+ """
537
+ Check the status of a link in terms of its position within the elements.
538
+
539
+ Args:
540
+ link (tuple): A tuple representing the start and end of the link.
541
+ keep_elements (list): List of elements to keep.
542
+
543
+ Returns:
544
+ str: Status of the link ('middle', 'start', or 'end').
545
+ """
546
  if link[0] in keep_elements and link[1] in keep_elements:
547
  return 'middle'
548
  elif link[0] is None and link[1] in keep_elements:
 
551
  return 'end'
552
  else:
553
  return 'middle'
554
+
555
  def check_data_association(i, links, labels, keep_elements):
556
+ """
557
+ Check data associations for an element.
558
+
559
+ Args:
560
+ i (int): Index of the current element.
561
+ links (list): List of links between elements.
562
+ labels (list): List of labels for each element.
563
+ keep_elements (list): List of elements to keep.
564
+
565
+ Returns:
566
+ tuple: Status and indices of data associations.
567
+ """
568
  status, links_idx = [], []
569
+ for j, (k, l) in enumerate(links):
570
  if labels[j] == list(class_dict.values()).index('dataAssociation'):
571
+ if k == i:
572
  status.append('output')
573
  links_idx.append(j)
574
+ elif l == i:
575
  status.append('input')
576
  links_idx.append(j)
577
 
578
  return status, links_idx
579
 
580
+ def create_data_Association(bpmn, data, size, element_id, current_idx, source_id, target_id):
581
+ """
582
+ Create a data association in the BPMN diagram.
583
+
584
+ Args:
585
+ bpmn (Element): The parent XML element.
586
+ data (dict): Dictionary containing prediction results.
587
+ size (dict): Dictionary containing element sizes.
588
+ element_id (str): The ID of the BPMN element.
589
+ current_idx (int): Index of the current element.
590
+ source_id (str): The source element ID.
591
+ target_id (str): The target element ID.
592
+ """
593
  waypoints = calculate_waypoints(data, size, current_idx, source_id, target_id)
594
  if waypoints is not None:
595
  add_diagram_edge(bpmn, element_id, waypoints)
596
+
597
  def check_eventBasedGateway(i, links, labels):
598
+ """
599
+ Check event-based gateway for an element.
600
+
601
+ Args:
602
+ i (int): Index of the current element.
603
+ links (list): List of links between elements.
604
+ labels (list): List of labels for each element.
605
+
606
+ Returns:
607
+ tuple: Status and indices of event-based gateway.
608
+ """
609
  status, links_idx = [], []
610
+ for j, (k, l) in enumerate(links):
611
  if labels[j] == list(class_dict.values()).index('sequenceFlow'):
612
+ if k == i:
613
  status.append('output')
614
  links_idx.append(j)
615
+ elif l == i:
616
  status.append('input')
617
  links_idx.append(j)
618
 
619
  return status, links_idx
620
+
621
  # Function to dynamically create and layout BPMN elements
622
  def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data, keep_elements):
623
+ """
624
+ Dynamically create and layout BPMN elements.
625
+
626
+ Args:
627
+ process (Element): The BPMN process element.
628
+ bpmnplane (Element): The BPMN plane element.
629
+ text_mapping (dict): Dictionary mapping BPMN IDs to text labels.
630
+ definitions (Element): The BPMN definitions element.
631
+ size (dict): Dictionary containing element sizes.
632
+ data (dict): Dictionary containing prediction results.
633
+ keep_elements (list): List of elements to keep.
634
+ """
635
  elements = data['BPMN_id']
636
  positions = data['boxes']
637
  links = data['links']
 
749
  sub_element = ET.SubElement(element, 'bpmn:eventBasedGateway', id=f'eventBasedGateway_{link_idx}_{gateway_name.split("_")[1]}')
750
  create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], i, element_id, gateway_name)
751
 
 
752
  add_diagram_elements(bpmnplane, element_id, x, y, size['eventBasedGateway'][0], size['eventBasedGateway'][1])
753
 
754
  # Data Object
 
770
  add_diagram_elements(bpmnplane, element_id, x, y, size['timerEvent'][0], size['timerEvent'][1])
771
 
772
  def calculate_pool_bounds(boxes, labels, keep_elements, size=None, class_dict=None):
773
+ """
774
+ Calculate the bounding box for a pool.
775
+
776
+ Args:
777
+ boxes (list): List of bounding boxes.
778
+ labels (list): List of labels for each element.
779
+ keep_elements (list): List of elements to keep.
780
+ size (dict, optional): Dictionary containing element sizes. Defaults to None.
781
+ class_dict (dict, optional): Dictionary mapping class indices to class names. Defaults to None.
782
+
783
+ Returns:
784
+ tuple: Minimum and maximum x and y coordinates of the pool.
785
+ """
786
  min_x, min_y = float('inf'), float('inf')
787
  max_x, max_y = float('-inf'), float('-inf')
788
 
 
813
 
814
  return min_x, min_y, max_x, max_y
815
 
 
 
816
  def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_element, target_element):
817
+ """
818
+ Calculate waypoints for connecting elements within a pool.
819
+
820
+ Args:
821
+ idx (int): Index of the current element.
822
+ data (dict): Dictionary containing prediction results.
823
+ size (dict): Dictionary containing element sizes.
824
+ source_idx (int): Index of the source element.
825
+ target_idx (int): Index of the target element.
826
+ source_element (str): Source element type.
827
+ target_element (str): Target element type.
828
+
829
+ Returns:
830
+ list: List of waypoints for the connection.
831
+ """
832
  # Get the bounding boxes of the source and target elements
833
  source_box = data['boxes'][source_idx]
834
  target_box = data['boxes'][target_idx]
 
863
  waypoints = [(element_mid_x, element_box[3]), (element_mid_x, pool_box[1])]
864
 
865
  return waypoints
 
866
  def add_curve(waypoints, pos_source, pos_target, threshold=30):
867
  """
868
  Add a single curve to the sequence flow by introducing a control point.
869
  The control point is added at an offset from the midpoint of the original waypoints.
870
+
871
+ Args:
872
+ waypoints (list): List of waypoints representing the path.
873
+ pos_source (str): Position of the source element ('left', 'right', 'top', 'bottom').
874
+ pos_target (str): Position of the target element ('left', 'right', 'top', 'bottom').
875
+ threshold (int, optional): Minimum distance to consider for adding a curve. Defaults to 30.
876
+
877
+ Returns:
878
+ list: List of waypoints with the added control point if applicable.
879
  """
880
  if len(waypoints) < 2:
881
  return waypoints
 
893
  if abs(start_x - end_x) < threshold or abs(start_y - end_y) < threshold:
894
  return waypoints
895
 
896
+ # Calculate the control point based on source and target positions
897
  if pos_source in pos_horizontal and pos_target in pos_horizontal:
898
  control_point = None
899
  elif pos_source in pos_vertical and pos_target in pos_vertical:
 
904
  control_point = (start_x, end_y)
905
  else:
906
  control_point = None
 
907
 
908
  # Create the curved path
909
  if control_point is not None:
 
913
 
914
  return curved_waypoints
915
 
 
916
  def calculate_waypoints(data, size, current_idx, source_id, target_id):
917
+ """
918
+ Calculate waypoints for connecting two elements in the diagram.
919
+
920
+ Args:
921
+ data (dict): Data containing diagram information.
922
+ size (dict): Dictionary of element sizes.
923
+ current_idx (int): Index of the current element.
924
+ source_id (str): ID of the source element.
925
+ target_id (str): ID of the target element.
926
+
927
+ Returns:
928
+ list: List of waypoints for the connection.
929
+ """
930
  best_points = data['best_points'][current_idx]
931
  pos_source = best_points[0]
932
  pos_target = best_points[1]
 
941
  if source_idx is None or target_idx is None:
942
  warning()
943
  return None
 
944
 
945
  name_source = source_id.split('_')[0]
946
  name_target = target_id.split('_')[0]
 
958
  warning()
959
  return [(source_x, source_y), (target_x, target_y)]
960
 
961
+ # Adjust the source coordinates based on its position
962
  if pos_source == 'left':
963
  source_x = source_x
964
  source_y += size[name_source][1] / 2
 
972
  source_x += size[name_source][0] / 2
973
  source_y += size[name_source][1]
974
 
975
+ # Adjust the target coordinates based on its position
976
  if pos_target == 'left':
977
  target_x = target_x
978
  target_y += size[name_target][1] / 2
 
996
 
997
  return curved_waypoints
998
 
 
999
  def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=False):
1000
+ """
1001
+ Create a BPMN flow element (sequence flow or message flow) and add it to the BPMN diagram.
1002
+
1003
+ Args:
1004
+ bpmn (ET.Element): The BPMN diagram element.
1005
+ text_mapping (dict): Dictionary mapping element IDs to their text labels.
1006
+ idx (int): Index of the current element.
1007
+ size (dict): Dictionary of element sizes.
1008
+ data (dict): Data containing diagram information.
1009
+ parent (ET.Element): The parent element to which the flow element is added.
1010
+ message (bool, optional): Whether the flow is a message flow. Defaults to False.
1011
+ """
1012
  source_idx, target_idx = data['links'][idx]
1013
 
1014
  if source_idx is None or target_idx is None:
 
1043
  return
1044
  element = ET.SubElement(parent, 'bpmn:sequenceFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
1045
  add_diagram_edge(bpmn, element_id, waypoints)
 
 
 
modules/train.py CHANGED
@@ -15,8 +15,6 @@ from tqdm import tqdm
15
  from modules.utils import write_results
16
 
17
 
18
-
19
-
20
  def get_arrow_model(num_classes, num_keypoints=2):
21
  """
22
  Configures and returns a modified Keypoint R-CNN model based on ResNet-50 with FPN, adapted for a custom number of classes and keypoints.
@@ -27,14 +25,6 @@ def get_arrow_model(num_classes, num_keypoints=2):
27
 
28
  Returns:
29
  - model (torch.nn.Module): The modified Keypoint R-CNN model.
30
-
31
- Steps:
32
- 1. Load a pre-trained Keypoint R-CNN model with a ResNet-50 backbone and Feature Pyramid Network (FPN).
33
- The model is initially configured for the COCO dataset, which includes various object classes and keypoints.
34
- 2. Replace the box predictor to adjust the number of output classes. The box predictor is responsible for
35
- classifying detected regions and predicting their bounding boxes.
36
- 3. Replace the keypoint predictor to adjust the number of keypoints the model predicts for each object.
37
- This is necessary to tailor the model to specific tasks that may have different keypoint structures.
38
  """
39
  # Load a model pre-trained on COCO, initialized without pre-trained weights
40
  model = keypointrcnn_resnet50_fpn(weights=None)
@@ -72,44 +62,60 @@ def get_faster_rcnn_model(num_classes):
72
 
73
  return model
74
 
75
- def prepare_model(dict,opti,learning_rate= 0.0003,model_to_load=None, model_type = 'object'):
76
- # Adjusted to pass the class_dict directly
77
- if model_type == 'object':
78
- model = get_faster_rcnn_model(len(dict))
79
- elif model_type == 'arrow':
80
- model = get_arrow_model(len(dict),2)
81
 
82
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
83
- # Load the model weights
84
- if model_to_load:
85
- model.load_state_dict(torch.load('./models/'+ model_to_load +'.pth', map_location=device))
86
- print(f"Model '{model_to_load}' loaded")
 
87
 
88
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
89
- model.to(device)
 
 
 
 
 
 
 
 
90
 
91
- if opti == 'SGD':
92
- #learning_rate= 0.002
93
- optimizer = SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
94
- elif opti == 'Adam':
95
- #learning_rate = 0.0003
96
- optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.00056, eps=1e-08, betas=(0.9, 0.999))
97
- else:
98
- print('Optimizer not found')
99
 
100
- return model, optimizer, device
101
 
 
 
 
 
 
 
102
 
103
- import copy
104
- from torch.optim import AdamW
105
- import time
106
- from modules.train import write_results
107
 
108
- import torch
109
- import numpy as np
110
- from tqdm import tqdm
111
 
112
  def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  model.train() # Set the model to evaluation mode
114
  total_loss = 0
115
 
@@ -174,12 +180,12 @@ def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=Fal
174
  avg_loss_keypoints = np.mean(loss_keypoints_list)
175
 
176
  if print_losses:
177
- print(f"Average Loss: {avg_loss:.4f}")
178
- print(f"Average Classifier Loss: {avg_loss_classifier:.4f}")
179
- print(f"Average Box Regression Loss: {avg_loss_box_reg:.4f}")
180
- print(f"Average Objectness Loss: {avg_loss_objectness:.4f}")
181
- print(f"Average RPN Box Regression Loss: {avg_loss_rpn_box_reg:.4f}")
182
- print(f"Average Keypoints Loss: {avg_loss_keypoints:.4f}")
183
 
184
  return avg_loss
185
 
@@ -188,206 +194,225 @@ def training_model(num_epochs, model, data_loader, subset_test_loader,
188
  optimizer, model_to_load=None, change_learning_rate=100, start_key=100,
189
  parameters=None, blur_prob=0.02,
190
  score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
191
- information_training='training', start_epoch=0, loss_config=None, model_type = 'object',
192
  eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
- # Set the model to training mode
195
- model.train()
196
-
197
- if loss_config is None:
198
- print('No loss config found, all losses will be used.')
199
- else:
200
- #print the list of the losses that will be used
201
- print('The following losses will be used: ', end='')
202
- for key, value in loss_config.items():
203
- if value:
204
- print(key, end=", ")
205
- print()
206
-
207
-
208
- # Initialize lists to store epoch-wise average losses
209
- epoch_avg_losses = []
210
- epoch_avg_loss_classifier = []
211
- epoch_avg_loss_box_reg = []
212
- epoch_avg_loss_objectness = []
213
- epoch_avg_loss_rpn_box_reg = []
214
- epoch_avg_loss_keypoints = []
215
- epoch_precision = []
216
- epoch_recall = []
217
- epoch_f1_score = []
218
- epoch_test_loss = []
219
-
220
-
221
- start_tot = time.time()
222
- best_metrics = -1000
223
- best_epoch = 0
224
- best_model_state = None
225
- same = 0
226
- learning_rate = optimizer.param_groups[0]['lr']
227
- bad_test_loss = 0
228
- previous_test_loss = 1000
229
-
230
- if parameters is not None:
231
- batch_size, crop_prob, rotate_90_proba, h_flip_prob, v_flip_prob, max_rotate_deg, rotate_proba, keep_ratio = parameters.values()
232
-
233
- print(f"Let's go training {model_type} model with {num_epochs} epochs!")
234
- if parameters is not None:
235
- print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, H flip prob: {h_flip_prob}, V flip prob: {v_flip_prob}, Max rotate deg: {max_rotate_deg}, Rotate proba: {rotate_proba}, Rotate 90 proba: {rotate_90_proba}, Keep ratio: {keep_ratio}")
236
-
237
- for epoch in range(num_epochs):
238
-
239
- if (epoch>0 and (epoch)%change_learning_rate == 0) or bad_test_loss>=3:
240
- learning_rate = 0.7*learning_rate
241
- optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
242
- if best_model_state is not None:
243
- model.load_state_dict(best_model_state)
244
- print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
245
- bad_test_loss = 0
246
- if epoch>0 and (epoch)==start_key:
247
- print("Now it's training Keypoints also")
248
- loss_config['loss_keypoint'] = True
249
- for name, param in model.named_parameters():
250
- if 'keypoint' in name:
251
- param.requires_grad = True
252
-
253
- model.train()
254
- start = time.time()
255
- total_loss = 0
256
-
257
- # Initialize lists to keep track of individual losses
258
- loss_classifier_list = []
259
- loss_box_reg_list = []
260
- loss_objectness_list = []
261
- loss_rpn_box_reg_list = []
262
- loss_keypoints_list = []
263
-
264
- # Create a tqdm progress bar
265
- progress_bar = tqdm(data_loader, desc=f'Epoch {epoch+1+start_epoch}')
266
-
267
- for images, targets_im in progress_bar:
268
- images = [image.to(device) for image in images]
269
- targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]
270
-
271
- optimizer.zero_grad()
272
-
273
- loss_dict = model(images, targets)
274
- # Inside the training loop where losses are calculated:
275
- losses = 0
276
- if loss_config is not None:
277
- for key, loss in loss_dict.items():
278
- if loss_config.get(key, False):
279
- if key == 'loss_classifier':
280
- loss *= 3
281
- losses += loss
282
- else:
283
- losses = sum(loss for key, loss in loss_dict.items())
284
-
285
- # Collect individual losses
286
- if loss_dict['loss_classifier']:
287
- loss_classifier_list.append(loss_dict['loss_classifier'].item())
288
- else:
289
- loss_classifier_list.append(0)
290
-
291
- if loss_dict['loss_box_reg']:
292
- loss_box_reg_list.append(loss_dict['loss_box_reg'].item())
293
- else:
294
- loss_box_reg_list.append(0)
295
-
296
- if loss_dict['loss_objectness']:
297
- loss_objectness_list.append(loss_dict['loss_objectness'].item())
298
- else:
299
- loss_objectness_list.append(0)
300
-
301
- if loss_dict['loss_rpn_box_reg']:
302
- loss_rpn_box_reg_list.append(loss_dict['loss_rpn_box_reg'].item())
303
- else:
304
- loss_rpn_box_reg_list.append(0)
305
-
306
- if 'loss_keypoint' in loss_dict:
307
- loss_keypoints_list.append(loss_dict['loss_keypoint'].item())
308
- else:
309
- loss_keypoints_list.append(0)
310
-
311
-
312
- losses.backward()
313
- optimizer.step()
314
-
315
- total_loss += losses.item()
316
-
317
- # Update the description with the current loss
318
- progress_bar.set_description(f'Epoch {epoch+1+start_epoch}, Loss: {losses.item():.4f}')
319
-
320
- # Calculate average loss
321
- avg_loss = total_loss / len(data_loader)
322
-
323
- epoch_avg_losses.append(avg_loss)
324
- epoch_avg_loss_classifier.append(np.mean(loss_classifier_list))
325
- epoch_avg_loss_box_reg.append(np.mean(loss_box_reg_list))
326
- epoch_avg_loss_objectness.append(np.mean(loss_objectness_list))
327
- epoch_avg_loss_rpn_box_reg.append(np.mean(loss_rpn_box_reg_list))
328
- epoch_avg_loss_keypoints.append(np.mean(loss_keypoints_list))
329
 
 
330
 
331
- # Evaluate the model on the test set
332
- if eval_metric == 'loss':
333
- labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0,0,0,0,0,0
334
- avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
335
- print(f"Epoch {epoch+1+start_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
336
- else:
337
- avg_test_loss = 0
338
- labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader,score_threshold=0.5, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)
339
- print(f"Epoch {epoch+1+start_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
340
- avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
341
- print(f"Epoch {epoch+1+start_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
342
-
343
- print(f"Time: {time.time() - start:.2f} [s]")
344
-
345
- if eval_metric == 'f1_score':
346
- metric_used = f1_score
347
- elif eval_metric == 'precision':
348
- metric_used = precision
349
- elif eval_metric == 'recall':
350
- metric_used = recall
351
- else:
352
- metric_used = -avg_test_loss
353
-
354
- # Check if this epoch's model has the lowest average loss
355
- if metric_used > best_metrics:
356
- best_metrics = metric_used
357
- best_epoch = epoch+1+start_epoch
358
- best_model_state = copy.deepcopy(model.state_dict())
359
-
360
- if epoch>0 and f1_score>early_stop_f1_score:
361
- same+=1
362
-
363
- epoch_precision.append(precision)
364
- epoch_recall.append(recall)
365
- epoch_f1_score.append(f1_score)
366
- epoch_test_loss.append(avg_test_loss)
367
-
368
- name_model = f"model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}"
369
- metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
370
-
371
- if same >=1 :
372
- torch.save(best_model_state, './models/'+ name_model +'.pth')
373
- write_results(name_model,metrics_list,start_epoch)
374
- break
375
-
376
- if (epoch+1+start_epoch) % 5 == 0:
377
- torch.save(best_model_state, './models/'+ name_model +'.pth')
378
- model.load_state_dict(best_model_state)
379
- write_results(name_model,metrics_list,start_epoch)
380
 
381
- if avg_test_loss > previous_test_loss:
382
- bad_test_loss += 1
383
- previous_test_loss = avg_test_loss
 
 
 
 
 
 
384
 
 
 
385
 
386
- print(f"\n Total time: {(time.time() - start_tot)/60} minutes, Best Epoch is {best_epoch} with an {eval_metric} of {best_metrics:.4f}")
387
- if best_model_state:
388
- torch.save(best_model_state, './models/'+ name_model +'.pth')
389
- model.load_state_dict(best_model_state)
390
- write_results(name_model,metrics_list,start_epoch)
391
- print(f"Name of the best model: model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}")
 
 
 
 
 
 
 
 
392
 
393
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  from modules.utils import write_results
16
 
17
 
 
 
18
  def get_arrow_model(num_classes, num_keypoints=2):
19
  """
20
  Configures and returns a modified Keypoint R-CNN model based on ResNet-50 with FPN, adapted for a custom number of classes and keypoints.
 
25
 
26
  Returns:
27
  - model (torch.nn.Module): The modified Keypoint R-CNN model.
 
 
 
 
 
 
 
 
28
  """
29
  # Load a model pre-trained on COCO, initialized without pre-trained weights
30
  model = keypointrcnn_resnet50_fpn(weights=None)
 
62
 
63
  return model
64
 
65
+ def prepare_model(dict, opti, learning_rate=0.0003, model_to_load=None, model_type='object'):
66
+ """
67
+ Prepares the model and optimizer for training.
 
 
 
68
 
69
+ Parameters:
70
+ - dict (dict): Dictionary of classes.
71
+ - opti (str): Optimizer type ('SGD' or 'Adam').
72
+ - learning_rate (float): Learning rate for the optimizer.
73
+ - model_to_load (str, optional): Name of the model to load.
74
+ - model_type (str): Type of model to prepare ('object' or 'arrow').
75
 
76
+ Returns:
77
+ - model (torch.nn.Module): The prepared model.
78
+ - optimizer (torch.optim.Optimizer): The configured optimizer.
79
+ - device (torch.device): The device (CPU or CUDA) on which to perform training.
80
+ """
81
+ # Adjusted to pass the class_dict directly
82
+ if model_type == 'object':
83
+ model = get_faster_rcnn_model(len(dict))
84
+ elif model_type == 'arrow':
85
+ model = get_arrow_model(len(dict), 2)
86
 
87
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
88
+ # Load the model weights
89
+ if model_to_load:
90
+ model.load_state_dict(torch.load('./models/' + model_to_load + '.pth', map_location=device))
91
+ print(f"Model '{model_to_load}' loaded")
 
 
 
92
 
93
+ model.to(device)
94
 
95
+ if opti == 'SGD':
96
+ optimizer = SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
97
+ elif opti == 'Adam':
98
+ optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.00056, eps=1e-08, betas=(0.9, 0.999))
99
+ else:
100
+ print('Optimizer not found')
101
 
102
+ return model, optimizer, device
 
 
 
103
 
 
 
 
104
 
105
  def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
106
+ """
107
+ Evaluate the loss of the model on a validation dataset.
108
+
109
+ Parameters:
110
+ - model (torch.nn.Module): The model to evaluate.
111
+ - data_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
112
+ - device (torch.device): Device to perform evaluation on.
113
+ - loss_config (dict, optional): Configuration specifying which losses to use.
114
+ - print_losses (bool): Whether to print individual loss components.
115
+
116
+ Returns:
117
+ - float: Average loss over the validation dataset.
118
+ """
119
  model.train() # Set the model to evaluation mode
120
  total_loss = 0
121
 
 
180
  avg_loss_keypoints = np.mean(loss_keypoints_list)
181
 
182
  if print_losses:
183
+ print(f"Average Loss: {avg_loss:.4f}")
184
+ print(f"Average Classifier Loss: {avg_loss_classifier:.4f}")
185
+ print(f"Average Box Regression Loss: {avg_loss_box_reg:.4f}")
186
+ print(f"Average Objectness Loss: {avg_loss_objectness:.4f}")
187
+ print(f"Average RPN Box Regression Loss: {avg_loss_rpn_box_reg:.4f}")
188
+ print(f"Average Keypoints Loss: {avg_loss_keypoints:.4f}")
189
 
190
  return avg_loss
191
 
 
194
  optimizer, model_to_load=None, change_learning_rate=100, start_key=100,
195
  parameters=None, blur_prob=0.02,
196
  score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
197
+ information_training='training', start_epoch=0, loss_config=None, model_type='object',
198
  eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
199
+ """
200
+ Train the model over a specified number of epochs.
201
+
202
+ Parameters:
203
+ - num_epochs (int): Number of epochs to train for.
204
+ - model (torch.nn.Module): Model to train.
205
+ - data_loader (torch.utils.data.DataLoader): DataLoader for the training dataset.
206
+ - subset_test_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
207
+ - optimizer (torch.optim.Optimizer): Optimizer to use for training.
208
+ - model_to_load (str, optional): Name of the model to load.
209
+ - change_learning_rate (int): Epoch interval to change the learning rate.
210
+ - start_key (int): Epoch to start training keypoints.
211
+ - parameters (dict, optional): Additional training parameters.
212
+ - blur_prob (float): Probability of applying blur augmentation.
213
+ - score_threshold (float): Score threshold for evaluation.
214
+ - iou_threshold (float): IoU threshold for evaluation.
215
+ - early_stop_f1_score (float): F1 score threshold for early stopping.
216
+ - information_training (str): Information about the training.
217
+ - start_epoch (int): Starting epoch number.
218
+ - loss_config (dict, optional): Configuration specifying which losses to use.
219
+ - model_type (str): Type of model ('object' or 'arrow').
220
+ - eval_metric (str): Evaluation metric ('f1_score', 'precision', 'recall', or 'loss').
221
+ - device (torch.device): Device to perform training on.
222
 
223
+ Returns:
224
+ - model (torch.nn.Module): Trained model.
225
+ """
226
+ model.train()
227
+
228
+ if loss_config is None:
229
+ print('No loss config found, all losses will be used.')
230
+ else:
231
+ # Print the list of the losses that will be used
232
+ print('The following losses will be used: ', end='')
233
+ for key, value in loss_config.items():
234
+ if value:
235
+ print(key, end=", ")
236
+ print()
237
+
238
+ # Initialize lists to store epoch-wise average losses
239
+ epoch_avg_losses = []
240
+ epoch_avg_loss_classifier = []
241
+ epoch_avg_loss_box_reg = []
242
+ epoch_avg_loss_objectness = []
243
+ epoch_avg_loss_rpn_box_reg = []
244
+ epoch_avg_loss_keypoints = []
245
+ epoch_precision = []
246
+ epoch_recall = []
247
+ epoch_f1_score = []
248
+ epoch_test_loss = []
249
+
250
+ start_tot = time.time()
251
+ best_metrics = -1000
252
+ best_epoch = 0
253
+ best_model_state = None
254
+ same = 0
255
+ learning_rate = optimizer.param_groups[0]['lr']
256
+ bad_test_loss = 0
257
+ previous_test_loss = 1000
258
+
259
+ if parameters is not None:
260
+ batch_size, crop_prob, rotate_90_proba, h_flip_prob, v_flip_prob, max_rotate_deg, rotate_proba, keep_ratio = parameters.values()
261
+
262
+ print(f"Let's go training {model_type} model with {num_epochs} epochs!")
263
+ if parameters is not None:
264
+ print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, H flip prob: {h_flip_prob}, V flip prob: {v_flip_prob}, Max rotate deg: {max_rotate_deg}, Rotate proba: {rotate_proba}, Rotate 90 proba: {rotate_90_proba}, Keep ratio: {keep_ratio}")
265
+
266
+ for epoch in range(num_epochs):
267
+ if (epoch > 0 and (epoch) % change_learning_rate == 0) or bad_test_loss >= 3:
268
+ learning_rate = 0.7 * learning_rate
269
+ optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
270
+ if best_model_state is not None:
271
+ model.load_state_dict(best_model_state)
272
+ print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
273
+ bad_test_loss = 0
274
+ if epoch > 0 and (epoch) == start_key:
275
+ print("Now it's training Keypoints also")
276
+ loss_config['loss_keypoint'] = True
277
+ for name, param in model.named_parameters():
278
+ if 'keypoint' in name:
279
+ param.requires_grad = True
280
+
281
+ model.train()
282
+ start = time.time()
283
+ total_loss = 0
284
+
285
+ # Initialize lists to keep track of individual losses
286
+ loss_classifier_list = []
287
+ loss_box_reg_list = []
288
+ loss_objectness_list = []
289
+ loss_rpn_box_reg_list = []
290
+ loss_keypoints_list = []
291
+
292
+ # Create a tqdm progress bar
293
+ progress_bar = tqdm(data_loader, desc=f'Epoch {epoch + 1 + start_epoch}')
294
+
295
+ for images, targets_im in progress_bar:
296
+ images = [image.to(device) for image in images]
297
+ targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
 
299
+ optimizer.zero_grad()
300
 
301
+ loss_dict = model(images, targets)
302
+ # Inside the training loop where losses are calculated:
303
+ losses = 0
304
+ if loss_config is not None:
305
+ for key, loss in loss_dict.items():
306
+ if loss_config.get(key, False):
307
+ if key == 'loss_classifier':
308
+ loss *= 3
309
+ losses += loss
310
+ else:
311
+ losses = sum(loss for key, loss in loss_dict.items())
312
+
313
+ # Collect individual losses
314
+ if loss_dict['loss_classifier']:
315
+ loss_classifier_list.append(loss_dict['loss_classifier'].item())
316
+ else:
317
+ loss_classifier_list.append(0)
318
+
319
+ if loss_dict['loss_box_reg']:
320
+ loss_box_reg_list.append(loss_dict['loss_box_reg'].item())
321
+ else:
322
+ loss_box_reg_list.append(0)
323
+
324
+ if loss_dict['loss_objectness']:
325
+ loss_objectness_list.append(loss_dict['loss_objectness'].item())
326
+ else:
327
+ loss_objectness_list.append(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
+ if loss_dict['loss_rpn_box_reg']:
330
+ loss_rpn_box_reg_list.append(loss_dict['loss_rpn_box_reg'].item())
331
+ else:
332
+ loss_rpn_box_reg_list.append(0)
333
+
334
+ if 'loss_keypoint' in loss_dict:
335
+ loss_keypoints_list.append(loss_dict['loss_keypoint'].item())
336
+ else:
337
+ loss_keypoints_list.append(0)
338
 
339
+ losses.backward()
340
+ optimizer.step()
341
 
342
+ total_loss += losses.item()
343
+
344
+ # Update the description with the current loss
345
+ progress_bar.set_description(f'Epoch {epoch + 1 + start_epoch}, Loss: {losses.item():.4f}')
346
+
347
+ # Calculate average loss
348
+ avg_loss = total_loss / len(data_loader)
349
+
350
+ epoch_avg_losses.append(avg_loss)
351
+ epoch_avg_loss_classifier.append(np.mean(loss_classifier_list))
352
+ epoch_avg_loss_box_reg.append(np.mean(loss_box_reg_list))
353
+ epoch_avg_loss_objectness.append(np.mean(loss_objectness_list))
354
+ epoch_avg_loss_rpn_box_reg.append(np.mean(loss_rpn_box_reg_list))
355
+ epoch_avg_loss_keypoints.append(np.mean(loss_keypoints_list))
356
 
357
+ # Evaluate the model on the test set
358
+ if eval_metric == 'loss':
359
+ labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0, 0, 0, 0, 0, 0
360
+ avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
361
+ print(f"Epoch {epoch + 1 + start_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
362
+ else:
363
+ avg_test_loss = 0
364
+ labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)
365
+ print(f"Epoch {epoch + 1 + start_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
366
+ avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
367
+ print(f"Epoch {epoch + 1 + start_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
368
+
369
+ print(f"Time: {time.time() - start:.2f} [s]")
370
+
371
+ if eval_metric == 'f1_score':
372
+ metric_used = f1_score
373
+ elif eval_metric == 'precision':
374
+ metric_used = precision
375
+ elif eval_metric == 'recall':
376
+ metric_used = recall
377
+ else:
378
+ metric_used = -avg_test_loss
379
+
380
+ # Check if this epoch's model has the lowest average loss
381
+ if metric_used > best_metrics:
382
+ best_metrics = metric_used
383
+ best_epoch = epoch + 1 + start_epoch
384
+ best_model_state = copy.deepcopy(model.state_dict())
385
+
386
+ if epoch > 0 and f1_score > early_stop_f1_score:
387
+ same += 1
388
+
389
+ epoch_precision.append(precision)
390
+ epoch_recall.append(recall)
391
+ epoch_f1_score.append(f1_score)
392
+ epoch_test_loss.append(avg_test_loss)
393
+
394
+ name_model = f"model_{type(optimizer).__name__}_{epoch + 1 + start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob * 10)}_crop0{int(crop_prob * 10)}_flip0{int(h_flip_prob * 10)}_rotate0{int(rotate_proba * 10)}_{information_training}"
395
+ metrics_list = [epoch_avg_losses, epoch_avg_loss_classifier, epoch_avg_loss_box_reg, epoch_avg_loss_objectness, epoch_avg_loss_rpn_box_reg, epoch_avg_loss_keypoints, epoch_precision, epoch_recall, epoch_f1_score, epoch_test_loss]
396
+
397
+ if same >= 1:
398
+ torch.save(best_model_state, './models/' + name_model + '.pth')
399
+ write_results(name_model, metrics_list, start_epoch)
400
+ break
401
+
402
+ if (epoch + 1 + start_epoch) % 5 == 0:
403
+ torch.save(best_model_state, './models/' + name_model + '.pth')
404
+ model.load_state_dict(best_model_state)
405
+ write_results(name_model, metrics_list, start_epoch)
406
+
407
+ if avg_test_loss > previous_test_loss:
408
+ bad_test_loss += 1
409
+ previous_test_loss = avg_test_loss
410
+
411
+ print(f"\n Total time: {(time.time() - start_tot) / 60} minutes, Best Epoch is {best_epoch} with an {eval_metric} of {best_metrics:.4f}")
412
+ if best_model_state:
413
+ torch.save(best_model_state, './models/' + name_model + '.pth')
414
+ model.load_state_dict(best_model_state)
415
+ write_results(name_model, metrics_list, start_epoch)
416
+ print(f"Name of the best model: model_{type(optimizer).__name__}_{epoch + 1 + start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob * 10)}_crop0{int(crop_prob * 10)}_flip0{int(h_flip_prob * 10)}_rotate0{int(rotate_proba * 10)}_{information_training}")
417
+
418
+ return model
modules/utils.py CHANGED
@@ -1,59 +1,11 @@
1
- from torchvision.models.detection import keypointrcnn_resnet50_fpn
2
- from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
3
- from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
4
- from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
5
- import random
6
  import torch
7
- from torch.utils.data import Dataset
8
  import torchvision.transforms.functional as F
9
  import numpy as np
10
- from torch.utils.data.dataloader import default_collate
11
  import cv2
12
  import matplotlib.pyplot as plt
13
- from torch.utils.data import DataLoader, Subset, ConcatDataset
14
  import streamlit as st
15
 
16
-
17
- """object_dict = {
18
- 0: 'background',
19
- 1: 'task',
20
- 2: 'exclusiveGateway',
21
- 3: 'eventBasedGateway',
22
- 4: 'event',
23
- 5: 'messageEvent',
24
- 6: 'timerEvent',
25
- 7: 'dataObject',
26
- 8: 'dataStore',
27
- 9: 'pool',
28
- 10: 'lane',
29
- }
30
-
31
-
32
- arrow_dict = {
33
- 0: 'background',
34
- 1: 'sequenceFlow',
35
- 2: 'dataAssociation',
36
- 3: 'messageFlow',
37
- }
38
-
39
- class_dict = {
40
- 0: 'background',
41
- 1: 'task',
42
- 2: 'exclusiveGateway',
43
- 3: 'eventBasedGateway',
44
- 4: 'event',
45
- 5: 'messageEvent',
46
- 6: 'timerEvent',
47
- 7: 'dataObject',
48
- 8: 'dataStore',
49
- 9: 'pool',
50
- 10: 'lane',
51
- 11: 'sequenceFlow',
52
- 12: 'dataAssociation',
53
- 13: 'messageFlow',
54
- }"""
55
-
56
-
57
  object_dict = {
58
  0: 'background',
59
  1: 'task',
@@ -96,7 +48,6 @@ class_dict = {
96
  15: 'messageFlow',
97
  }
98
 
99
-
100
  def is_inside(box1, box2):
101
  """Check if the center of box1 is inside box2."""
102
  x_center = (box1[0] + box1[2]) / 2
@@ -107,51 +58,31 @@ def is_vertical(box):
107
  """Determine if the text in the bounding box is vertically aligned."""
108
  width = box[2] - box[0]
109
  height = box[3] - box[1]
110
- return (height > 2*width)
111
 
112
  def rescale_boxes(scale, boxes):
 
113
  for i in range(len(boxes)):
114
- boxes[i] = [boxes[i][0]*scale,
115
- boxes[i][1]*scale,
116
- boxes[i][2]*scale,
117
- boxes[i][3]*scale]
118
  return boxes
119
 
120
  def iou(box1, box2):
121
- # Calcule l'intersection des deux boîtes englobantes
122
  inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])]
123
  inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
124
-
125
- # Calcule l'union des deux boîtes englobantes
126
  box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
127
  box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
128
  union_area = box1_area + box2_area - inter_area
129
-
130
  return inter_area / union_area
131
 
132
  def proportion_inside(box1, box2):
133
- # Calculate the areas of both boxes
134
  box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
135
  box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
136
-
137
- # Determine the bigger and smaller boxes
138
- if box1_area > box2_area:
139
- big_box = box1
140
- small_box = box2
141
- else:
142
- big_box = box2
143
- small_box = box1
144
-
145
- # Calculate the intersection of the two bounding boxes
146
  inter_box = [max(small_box[0], big_box[0]), max(small_box[1], big_box[1]), min(small_box[2], big_box[2]), min(small_box[3], big_box[3])]
147
  inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
148
-
149
- # Calculate the proportion of the smaller box inside the bigger box
150
- if (small_box[2] - small_box[0]) * (small_box[3] - small_box[1]) == 0:
151
- return 0
152
  proportion = inter_area / ((small_box[2] - small_box[0]) * (small_box[3] - small_box[1]))
153
-
154
- # Ensure the proportion is at most 100%
155
  return min(proportion, 1.0)
156
 
157
  def resize_boxes(boxes, original_size, target_size):
@@ -168,20 +99,15 @@ def resize_boxes(boxes, original_size, target_size):
168
  """
169
  orig_width, orig_height = original_size
170
  target_width, target_height = target_size
171
-
172
- # Calculate the ratios for width and height
173
  width_ratio = target_width / orig_width
174
  height_ratio = target_height / orig_height
175
-
176
- # Apply the ratios to the bounding boxes
177
  boxes[:, 0] *= width_ratio
178
  boxes[:, 1] *= height_ratio
179
  boxes[:, 2] *= width_ratio
180
  boxes[:, 3] *= height_ratio
181
-
182
  return boxes
183
 
184
- def resize_keypoints(keypoints: np.ndarray, original_size: tuple, target_size: tuple) -> np.ndarray:
185
  """
186
  Resize keypoints based on the original and target dimensions of an image.
187
 
@@ -192,40 +118,38 @@ def resize_keypoints(keypoints: np.ndarray, original_size: tuple, target_size: t
192
 
193
  Returns:
194
  - np.ndarray: The resized keypoints.
195
-
196
- Explanation:
197
- The function calculates the ratio of the target dimensions to the original dimensions.
198
- It then applies these ratios to the x and y coordinates of each keypoint to scale them
199
- appropriately to the target image size.
200
  """
201
-
202
  orig_width, orig_height = original_size
203
  target_width, target_height = target_size
204
-
205
- # Calculate the ratios for width and height scaling
206
  width_ratio = target_width / orig_width
207
  height_ratio = target_height / orig_height
208
-
209
- # Apply the scaling ratios to the x and y coordinates of each keypoint
210
- keypoints[:, 0] *= width_ratio # Scale x coordinates
211
- keypoints[:, 1] *= height_ratio # Scale y coordinates
212
-
213
  return keypoints
214
 
215
-
216
- def write_results(name_model,metrics_list,start_epoch):
217
- with open('./results/'+ name_model+ '.txt', 'w') as f:
218
  for i in range(len(metrics_list[0])):
219
- f.write(f"{i+1+start_epoch},{metrics_list[0][i]},{metrics_list[1][i]},{metrics_list[2][i]},{metrics_list[3][i]},{metrics_list[4][i]},{metrics_list[5][i]},{metrics_list[6][i]},{metrics_list[7][i]},{metrics_list[8][i]},{metrics_list[9][i]} \n")
220
-
221
 
222
  def find_other_keypoint(idx, keypoints, boxes):
 
 
 
 
 
 
 
 
 
 
 
223
  box = boxes[idx]
224
- key1,key2 = keypoints[idx]
225
  x1, y1, x2, y2 = box
226
  center = ((x1 + x2) // 2, (y1 + y2) // 2)
227
  average_keypoint = (key1 + key2) // 2
228
- #find the opposite keypoint to the center
229
  if average_keypoint[0] < center[0]:
230
  x = center[0] + abs(center[0] - average_keypoint[0])
231
  else:
@@ -235,7 +159,6 @@ def find_other_keypoint(idx, keypoints, boxes):
235
  else:
236
  y = center[1] - abs(center[1] - average_keypoint[1])
237
  return x, y, average_keypoint[0], average_keypoint[1]
238
-
239
 
240
  def filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5):
241
  """
@@ -251,47 +174,28 @@ def filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5):
251
  Returns:
252
  - tuple: Filtered boxes, scores, labels, and keypoints.
253
  """
254
- # Calculate the area of each bounding box to use in IoU calculation.
255
  areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
256
-
257
- # Sort the indices of the boxes based on their scores in descending order.
258
  order = scores.argsort()[::-1]
259
-
260
- keep = [] # List to store indices of boxes to keep.
261
-
262
  while order.size > 0:
263
- # Take the first index (highest score) from the sorted list.
264
  i = order[0]
265
- keep.append(i) # Add this index to 'keep' list.
266
-
267
- # Compute the coordinates of the intersection rectangle.
268
  xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
269
  yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
270
  xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
271
  yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
272
-
273
- # Compute the area of the intersection rectangle.
274
  w = np.maximum(0.0, xx2 - xx1)
275
  h = np.maximum(0.0, yy2 - yy1)
276
  inter = w * h
277
-
278
- # Calculate IoU and find boxes with IoU less than the threshold to keep.
279
  iou = inter / (areas[i] + areas[order[1:]] - inter)
280
  inds = np.where(iou <= iou_threshold)[0]
281
-
282
- # Update the list of box indices to consider in the next iteration.
283
- order = order[inds + 1] # Skip the first element since it's already included in 'keep'.
284
-
285
- # Use the indices in 'keep' to select the boxes, scores, labels, and keypoints to return.
286
  boxes = boxes[keep]
287
  scores = scores[keep]
288
  labels = labels[keep]
289
  keypoints = keypoints[keep]
290
-
291
  return boxes, scores, labels, keypoints
292
 
293
-
294
-
295
  def draw_annotations(image,
296
  target=None,
297
  prediction=None,
@@ -312,7 +216,7 @@ def draw_annotations(image,
312
  only_print=None,
313
  axis=False,
314
  return_image=False,
315
- new_size=(1333,800),
316
  resize=False):
317
  """
318
  Draws annotations on images including bounding boxes, keypoints, links, and text.
@@ -328,7 +232,7 @@ def draw_annotations(image,
328
  - draw_boxes (bool): Flag to draw bounding boxes.
329
  - draw_text (bool): Flag to draw text annotations.
330
  - draw_links (bool): Flag to draw links between annotations.
331
- - draw_twins (bool): Flag to draw twins keypoints.
332
  - write_class (bool): Flag to write class names near the annotations.
333
  - write_score (bool): Flag to write scores near the annotations.
334
  - write_text (bool): Flag to write OCR recognized text.
@@ -345,137 +249,119 @@ def draw_annotations(image,
345
  image_copy = image.copy()
346
  scale = max(image.shape[0], image.shape[1]) / 1000
347
 
348
- # Function to draw bounding boxes and keypoints
349
- def draw(data,is_prediction=False):
350
- """ Helper function to draw annotations based on provided data. """
351
-
352
  for i in range(len(data['boxes'])):
 
 
 
 
353
  if is_prediction:
354
- box = data['boxes'][i].tolist()
355
- x1, y1, x2, y2 = box
356
- if resize:
357
- x1, y1, x2, y2 = resize_boxes(np.array([box]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
358
  score = data['scores'][i].item()
359
  if score < score_threshold:
360
  continue
361
- else:
362
- box = data['boxes'][i].tolist()
363
- x1, y1, x2, y2 = box
364
  if draw_boxes:
365
  if only_print is not None:
366
  if data['labels'][i] != list(model_dict.values()).index(only_print):
367
  continue
368
- cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0) if is_prediction else (0, 0, 0), int(2*scale))
369
  if is_prediction and write_score:
370
- cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (100,100, 255), 2)
371
 
372
  if write_class and 'labels' in data:
373
  class_id = data['labels'][i].item()
374
- cv2.putText(image_copy, model_dict[class_id], (int(x1), int(y1) - int(2*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (255, 100, 100), 2)
375
 
376
  if write_idx:
377
- cv2.putText(image_copy, str(i), (int(x1) + int(15*scale), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, 2*scale, (0,0, 0), 2)
378
-
379
 
380
  # Draw keypoints if available
381
  if draw_keypoints and 'keypoints' in data:
382
  if is_prediction and keypoints_correction:
383
  for idx, (key1, key2) in enumerate(data['keypoints']):
384
  if data['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
385
- list(model_dict.values()).index('messageFlow'),
386
- list(model_dict.values()).index('dataAssociation')]:
387
  continue
388
- # Calculate the Euclidean distance between the two keypoints
389
  distance = np.linalg.norm(key1[:2] - key2[:2])
390
-
391
  if distance < 5:
392
- x_new,y_new, x,y = find_other_keypoint(idx, data['keypoints'], data['boxes'])
393
- data['keypoints'][idx][0] = torch.tensor([x_new, y_new,1])
394
- data['keypoints'][idx][1] = torch.tensor([x, y,1])
395
  print("keypoint has been changed")
396
  for i in range(len(data['keypoints'])):
397
  kp = data['keypoints'][i]
398
  for j in range(kp.shape[0]):
399
- if is_prediction and data['labels'][i] != list(model_dict.values()).index('sequenceFlow') and data['labels'][i] != list(model_dict.values()).index('messageFlow') and data['labels'][i] != list(model_dict.values()).index('dataAssociation'):
 
 
400
  continue
401
  if is_prediction:
402
  score = data['scores'][i]
403
  if score < score_threshold:
404
  continue
405
- x,y,v = np.array(kp[j])
406
  if resize:
407
- x, y, v = resize_keypoints(np.array([kp[j]]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
408
  if j == 0:
409
- cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (0, 0, 255), -1)
410
  else:
411
- cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (255, 0, 0), -1)
412
 
413
  # Draw text predictions if available
414
- if (draw_text or write_text) and text_predictions is not None:
415
  for i in range(len(text_predictions[0])):
416
  x1, y1, x2, y2 = text_predictions[0][i]
417
  text = text_predictions[1][i]
418
  if resize:
419
- x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
420
  if draw_text:
421
- cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))
422
  if write_text:
423
- cv2.putText(image_copy, text, (int(x1 + int(2*scale)), int((y1+y2)/2) ), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (0,0, 0), 2)
424
-
425
  def draw_with_links(full_prediction):
426
- '''Draws links between objects based on the full prediction data.'''
427
- #check if keypoints detected are the same
428
  if draw_twins and full_prediction is not None:
429
- # Pre-calculate indices for performance
430
- circle_color = (0, 255, 0) # Green color for the circle
431
- circle_radius = int(10 * scale) # Circle radius scaled by image scale
432
-
433
  for idx, (key1, key2) in enumerate(full_prediction['keypoints']):
434
  if full_prediction['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
435
- list(model_dict.values()).index('messageFlow'),
436
- list(model_dict.values()).index('dataAssociation')]:
437
  continue
438
- # Calculate the Euclidean distance between the two keypoints
439
  distance = np.linalg.norm(key1[:2] - key2[:2])
440
  if distance < 10:
441
- x_new,y_new, x,y = find_other_keypoint(idx,full_prediction)
442
  cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
443
- cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0,0,0), -1)
444
 
445
- # Draw links between objects
446
- if draw_links==True and full_prediction is not None:
447
  for i, (start_idx, end_idx) in enumerate(full_prediction['links']):
448
  if start_idx is None or end_idx is None:
449
  continue
450
  start_box = full_prediction['boxes'][start_idx]
451
  end_box = full_prediction['boxes'][end_idx]
452
  current_box = full_prediction['boxes'][i]
453
- # Calculate the center of each bounding box
454
  start_center = ((start_box[0] + start_box[2]) // 2, (start_box[1] + start_box[3]) // 2)
455
  end_center = ((end_box[0] + end_box[2]) // 2, (end_box[1] + end_box[3]) // 2)
456
  current_center = ((current_box[0] + current_box[2]) // 2, (current_box[1] + current_box[3]) // 2)
457
- # Draw a line between the centers of the connected objects
458
- cv2.line(image_copy, (int(start_center[0]), int(start_center[1])), (int(current_center[0]), int(current_center[1])), (0, 0, 255), int(2*scale))
459
- cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), int(2*scale))
460
 
461
- i+=1
462
 
463
- # Draw GT annotations
464
  if target is not None:
465
  draw(target, is_prediction=False)
466
- # Draw predictions
467
  if prediction is not None:
468
- #prediction = prediction[0]
469
  draw(prediction, is_prediction=True)
470
- # Draw links with full predictions
471
  if full_prediction is not None:
472
  draw_with_links(full_prediction)
473
 
474
- # Display the image
475
  image_copy = cv2.cvtColor(image_copy, cv2.COLOR_BGR2RGB)
476
  plt.figure(figsize=(12, 12))
477
  plt.imshow(image_copy)
478
- if axis==False:
479
  plt.axis('off')
480
  plt.show()
481
 
@@ -496,28 +382,24 @@ def find_closest_object(keypoint, boxes, labels):
496
  closest_object_idx = None
497
  best_point = None
498
  min_distance = float('inf')
499
- # Iterate over each bounding box
500
  for i, box in enumerate(boxes):
501
  if labels[i] in [list(class_dict.values()).index('sequenceFlow'),
502
  list(class_dict.values()).index('messageFlow'),
503
  list(class_dict.values()).index('dataAssociation'),
504
- #list(class_dict.values()).index('pool'),
505
  list(class_dict.values()).index('lane')]:
506
  continue
507
  x1, y1, x2, y2 = box
508
 
509
- top = ((x1+x2)/2, y1)
510
- bottom = ((x1+x2)/2, y2)
511
- left = (x1, (y1+y2)/2)
512
- right = (x2, (y1+y2)/2)
513
- points = [left, top , right, bottom]
514
 
515
- pos_dict = {0:'left', 1:'top', 2:'right', 3:'bottom'}
516
 
517
- # Calculate the distance between the keypoint and the center of the bounding box
518
- for pos, (point) in enumerate(points):
519
  distance = np.linalg.norm(keypoint[:2] - point)
520
- # Update the closest object index if this object is closer
521
  if distance < min_distance:
522
  min_distance = distance
523
  closest_object_idx = i
@@ -525,9 +407,10 @@ def find_closest_object(keypoint, boxes, labels):
525
 
526
  return closest_object_idx, best_point
527
 
528
-
529
  def error(text='There is an error in the detection'):
 
530
  st.error(text, icon="🚨")
531
 
532
  def warning(text='Some element are maybe not detected, verify the results, try to modify the parameters or try to add it in the method and style step.'):
 
533
  st.warning(text, icon="⚠️")
 
 
 
 
 
 
1
  import torch
 
2
  import torchvision.transforms.functional as F
3
  import numpy as np
 
4
  import cv2
5
  import matplotlib.pyplot as plt
 
6
  import streamlit as st
7
 
8
+ # Define dictionaries to map class indices to their corresponding names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  object_dict = {
10
  0: 'background',
11
  1: 'task',
 
48
  15: 'messageFlow',
49
  }
50
 
 
51
  def is_inside(box1, box2):
52
  """Check if the center of box1 is inside box2."""
53
  x_center = (box1[0] + box1[2]) / 2
 
58
  """Determine if the text in the bounding box is vertically aligned."""
59
  width = box[2] - box[0]
60
  height = box[3] - box[1]
61
+ return (height > 2 * width)
62
 
63
  def rescale_boxes(scale, boxes):
64
+ """Rescale the bounding boxes by a given scale factor."""
65
  for i in range(len(boxes)):
66
+ boxes[i] = [boxes[i][0] * scale, boxes[i][1] * scale, boxes[i][2] * scale, boxes[i][3] * scale]
 
 
 
67
  return boxes
68
 
69
  def iou(box1, box2):
70
+ """Calculate the Intersection over Union (IoU) of two bounding boxes."""
71
  inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])]
72
  inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
 
 
73
  box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
74
  box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
75
  union_area = box1_area + box2_area - inter_area
 
76
  return inter_area / union_area
77
 
78
  def proportion_inside(box1, box2):
79
+ """Calculate the proportion of the smaller box inside the larger box."""
80
  box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
81
  box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
82
+ big_box, small_box = (box1, box2) if box1_area > box2_area else (box2, box1)
 
 
 
 
 
 
 
 
 
83
  inter_box = [max(small_box[0], big_box[0]), max(small_box[1], big_box[1]), min(small_box[2], big_box[2]), min(small_box[3], big_box[3])]
84
  inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
 
 
 
 
85
  proportion = inter_area / ((small_box[2] - small_box[0]) * (small_box[3] - small_box[1]))
 
 
86
  return min(proportion, 1.0)
87
 
88
  def resize_boxes(boxes, original_size, target_size):
 
99
  """
100
  orig_width, orig_height = original_size
101
  target_width, target_height = target_size
 
 
102
  width_ratio = target_width / orig_width
103
  height_ratio = target_height / orig_height
 
 
104
  boxes[:, 0] *= width_ratio
105
  boxes[:, 1] *= height_ratio
106
  boxes[:, 2] *= width_ratio
107
  boxes[:, 3] *= height_ratio
 
108
  return boxes
109
 
110
+ def resize_keypoints(keypoints, original_size, target_size):
111
  """
112
  Resize keypoints based on the original and target dimensions of an image.
113
 
 
118
 
119
  Returns:
120
  - np.ndarray: The resized keypoints.
 
 
 
 
 
121
  """
 
122
  orig_width, orig_height = original_size
123
  target_width, target_height = target_size
 
 
124
  width_ratio = target_width / orig_width
125
  height_ratio = target_height / orig_height
126
+ keypoints[:, 0] *= width_ratio
127
+ keypoints[:, 1] *= height_ratio
 
 
 
128
  return keypoints
129
 
130
+ def write_results(name_model, metrics_list, start_epoch):
131
+ """Write training results to a text file."""
132
+ with open('./results/' + name_model + '.txt', 'w') as f:
133
  for i in range(len(metrics_list[0])):
134
+ f.write(f"{i + 1 + start_epoch},{metrics_list[0][i]},{metrics_list[1][i]},{metrics_list[2][i]},{metrics_list[3][i]},{metrics_list[4][i]},{metrics_list[5][i]},{metrics_list[6][i]},{metrics_list[7][i]},{metrics_list[8][i]},{metrics_list[9][i]} \n")
 
135
 
136
  def find_other_keypoint(idx, keypoints, boxes):
137
+ """
138
+ Find the opposite keypoint to the center of the box.
139
+
140
+ Parameters:
141
+ - idx (int): The index of the box and keypoints.
142
+ - keypoints (np.ndarray): The array of keypoints.
143
+ - boxes (np.ndarray): The array of bounding boxes.
144
+
145
+ Returns:
146
+ - tuple: The coordinates of the new keypoint and the average keypoint.
147
+ """
148
  box = boxes[idx]
149
+ key1, key2 = keypoints[idx]
150
  x1, y1, x2, y2 = box
151
  center = ((x1 + x2) // 2, (y1 + y2) // 2)
152
  average_keypoint = (key1 + key2) // 2
 
153
  if average_keypoint[0] < center[0]:
154
  x = center[0] + abs(center[0] - average_keypoint[0])
155
  else:
 
159
  else:
160
  y = center[1] - abs(center[1] - average_keypoint[1])
161
  return x, y, average_keypoint[0], average_keypoint[1]
 
162
 
163
  def filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5):
164
  """
 
174
  Returns:
175
  - tuple: Filtered boxes, scores, labels, and keypoints.
176
  """
 
177
  areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
 
 
178
  order = scores.argsort()[::-1]
179
+ keep = []
 
 
180
  while order.size > 0:
 
181
  i = order[0]
182
+ keep.append(i)
 
 
183
  xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
184
  yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
185
  xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
186
  yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
 
 
187
  w = np.maximum(0.0, xx2 - xx1)
188
  h = np.maximum(0.0, yy2 - yy1)
189
  inter = w * h
 
 
190
  iou = inter / (areas[i] + areas[order[1:]] - inter)
191
  inds = np.where(iou <= iou_threshold)[0]
192
+ order = order[inds + 1]
 
 
 
 
193
  boxes = boxes[keep]
194
  scores = scores[keep]
195
  labels = labels[keep]
196
  keypoints = keypoints[keep]
 
197
  return boxes, scores, labels, keypoints
198
 
 
 
199
  def draw_annotations(image,
200
  target=None,
201
  prediction=None,
 
216
  only_print=None,
217
  axis=False,
218
  return_image=False,
219
+ new_size=(1333, 800),
220
  resize=False):
221
  """
222
  Draws annotations on images including bounding boxes, keypoints, links, and text.
 
232
  - draw_boxes (bool): Flag to draw bounding boxes.
233
  - draw_text (bool): Flag to draw text annotations.
234
  - draw_links (bool): Flag to draw links between annotations.
235
+ - draw_twins (bool): Flag to draw twin keypoints.
236
  - write_class (bool): Flag to write class names near the annotations.
237
  - write_score (bool): Flag to write scores near the annotations.
238
  - write_text (bool): Flag to write OCR recognized text.
 
249
  image_copy = image.copy()
250
  scale = max(image.shape[0], image.shape[1]) / 1000
251
 
252
+ # Helper function to draw annotations based on provided data
253
+ def draw(data, is_prediction=False):
 
 
254
  for i in range(len(data['boxes'])):
255
+ box = data['boxes'][i].tolist()
256
+ x1, y1, x2, y2 = box
257
+ if resize:
258
+ x1, y1, x2, y2 = resize_boxes(np.array([box]), new_size, (image_copy.shape[1], image_copy.shape[0]))[0]
259
  if is_prediction:
 
 
 
 
260
  score = data['scores'][i].item()
261
  if score < score_threshold:
262
  continue
 
 
 
263
  if draw_boxes:
264
  if only_print is not None:
265
  if data['labels'][i] != list(model_dict.values()).index(only_print):
266
  continue
267
+ cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0) if is_prediction else (0, 0, 0), int(2 * scale))
268
  if is_prediction and write_score:
269
+ cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15 * scale)), cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (100, 100, 255), 2)
270
 
271
  if write_class and 'labels' in data:
272
  class_id = data['labels'][i].item()
273
+ cv2.putText(image_copy, model_dict[class_id], (int(x1), int(y1) - int(2 * scale)), cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (255, 100, 100), 2)
274
 
275
  if write_idx:
276
+ cv2.putText(image_copy, str(i), (int(x1) + int(15 * scale), int(y1) + int(15 * scale)), cv2.FONT_HERSHEY_SIMPLEX, 2 * scale, (0, 0, 0), 2)
 
277
 
278
  # Draw keypoints if available
279
  if draw_keypoints and 'keypoints' in data:
280
  if is_prediction and keypoints_correction:
281
  for idx, (key1, key2) in enumerate(data['keypoints']):
282
  if data['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
283
+ list(model_dict.values()).index('messageFlow'),
284
+ list(model_dict.values()).index('dataAssociation')]:
285
  continue
 
286
  distance = np.linalg.norm(key1[:2] - key2[:2])
 
287
  if distance < 5:
288
+ x_new, y_new, x, y = find_other_keypoint(idx, data['keypoints'], data['boxes'])
289
+ data['keypoints'][idx][0] = torch.tensor([x_new, y_new, 1])
290
+ data['keypoints'][idx][1] = torch.tensor([x, y, 1])
291
  print("keypoint has been changed")
292
  for i in range(len(data['keypoints'])):
293
  kp = data['keypoints'][i]
294
  for j in range(kp.shape[0]):
295
+ if is_prediction and data['labels'][i] not in [list(model_dict.values()).index('sequenceFlow'),
296
+ list(model_dict.values()).index('messageFlow'),
297
+ list(model_dict.values()).index('dataAssociation')]:
298
  continue
299
  if is_prediction:
300
  score = data['scores'][i]
301
  if score < score_threshold:
302
  continue
303
+ x, y, v = np.array(kp[j])
304
  if resize:
305
+ x, y, v = resize_keypoints(np.array([kp[j]]), new_size, (image_copy.shape[1], image_copy.shape[0]))[0]
306
  if j == 0:
307
+ cv2.circle(image_copy, (int(x), int(y)), int(5 * scale), (0, 0, 255), -1)
308
  else:
309
+ cv2.circle(image_copy, (int(x), int(y)), int(5 * scale), (255, 0, 0), -1)
310
 
311
  # Draw text predictions if available
312
+ if (draw_text or write_text) and text_predictions is not None:
313
  for i in range(len(text_predictions[0])):
314
  x1, y1, x2, y2 = text_predictions[0][i]
315
  text = text_predictions[1][i]
316
  if resize:
317
+ x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), new_size, (image_copy.shape[1], image_copy.shape[0]))[0]
318
  if draw_text:
319
+ cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2 * scale))
320
  if write_text:
321
+ cv2.putText(image_copy, text, (int(x1 + int(2 * scale)), int((y1 + y2) / 2)), cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (0, 0, 0), 2)
322
+
323
  def draw_with_links(full_prediction):
324
+ """Draws links between objects based on the full prediction data."""
 
325
  if draw_twins and full_prediction is not None:
326
+ circle_color = (0, 255, 0)
327
+ circle_radius = int(10 * scale)
 
 
328
  for idx, (key1, key2) in enumerate(full_prediction['keypoints']):
329
  if full_prediction['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
330
+ list(model_dict.values()).index('messageFlow'),
331
+ list(model_dict.values()).index('dataAssociation')]:
332
  continue
 
333
  distance = np.linalg.norm(key1[:2] - key2[:2])
334
  if distance < 10:
335
+ x_new, y_new, x, y = find_other_keypoint(idx, full_prediction['keypoints'], full_prediction['boxes'])
336
  cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
337
+ cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0, 0, 0), -1)
338
 
339
+ if draw_links and full_prediction is not None:
 
340
  for i, (start_idx, end_idx) in enumerate(full_prediction['links']):
341
  if start_idx is None or end_idx is None:
342
  continue
343
  start_box = full_prediction['boxes'][start_idx]
344
  end_box = full_prediction['boxes'][end_idx]
345
  current_box = full_prediction['boxes'][i]
 
346
  start_center = ((start_box[0] + start_box[2]) // 2, (start_box[1] + start_box[3]) // 2)
347
  end_center = ((end_box[0] + end_box[2]) // 2, (end_box[1] + end_box[3]) // 2)
348
  current_center = ((current_box[0] + current_box[2]) // 2, (current_box[1] + current_box[3]) // 2)
349
+ cv2.line(image_copy, (int(start_center[0]), int(start_center[1])), (int(current_center[0]), int(current_center[1])), (0, 0, 255), int(2 * scale))
350
+ cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), int(2 * scale))
 
351
 
352
+ i += 1
353
 
 
354
  if target is not None:
355
  draw(target, is_prediction=False)
 
356
  if prediction is not None:
 
357
  draw(prediction, is_prediction=True)
 
358
  if full_prediction is not None:
359
  draw_with_links(full_prediction)
360
 
 
361
  image_copy = cv2.cvtColor(image_copy, cv2.COLOR_BGR2RGB)
362
  plt.figure(figsize=(12, 12))
363
  plt.imshow(image_copy)
364
+ if not axis:
365
  plt.axis('off')
366
  plt.show()
367
 
 
382
  closest_object_idx = None
383
  best_point = None
384
  min_distance = float('inf')
 
385
  for i, box in enumerate(boxes):
386
  if labels[i] in [list(class_dict.values()).index('sequenceFlow'),
387
  list(class_dict.values()).index('messageFlow'),
388
  list(class_dict.values()).index('dataAssociation'),
 
389
  list(class_dict.values()).index('lane')]:
390
  continue
391
  x1, y1, x2, y2 = box
392
 
393
+ top = ((x1 + x2) / 2, y1)
394
+ bottom = ((x1 + x2) / 2, y2)
395
+ left = (x1, (y1 + y2) / 2)
396
+ right = (x2, (y1 + y2) / 2)
397
+ points = [left, top, right, bottom]
398
 
399
+ pos_dict = {0: 'left', 1: 'top', 2: 'right', 3: 'bottom'}
400
 
401
+ for pos, point in enumerate(points):
 
402
  distance = np.linalg.norm(keypoint[:2] - point)
 
403
  if distance < min_distance:
404
  min_distance = distance
405
  closest_object_idx = i
 
407
 
408
  return closest_object_idx, best_point
409
 
 
410
  def error(text='There is an error in the detection'):
411
+ """Display an error message using Streamlit."""
412
  st.error(text, icon="🚨")
413
 
414
  def warning(text='Some element are maybe not detected, verify the results, try to modify the parameters or try to add it in the method and style step.'):
415
+ """Display a warning message using Streamlit."""
416
  st.warning(text, icon="⚠️")