# sketch-to-BPMN / demo_streamlit.py
import gc
import xml.etree.ElementTree as ET
from pathlib import Path
from xml.dom import minidom

import gdown
import numpy as np
import psutil
import streamlit as st
import streamlit.components.v1 as components
import torch
from PIL import Image, ImageEnhance
from torchvision.transforms import functional as F
from streamlit_cropper import st_cropper
from streamlit_drawable_canvas import st_canvas
from streamlit_image_comparison import image_comparison

from display import draw_stream
from eval import full_prediction
from htlm_webpage import display_bpmn_xml
from OCR import text_prediction, filter_text, mapping_text, rescale
from toXML import calculate_pool_bounds, add_diagram_elements, create_bpmn_object, create_flow_element
from train import prepare_model, get_faster_rcnn_model, get_arrow_model
from utils import draw_annotations, create_loader, class_dict, arrow_dict, object_dict, find_closest_object


def get_memory_usage():
    process = psutil.Process()
    mem_info = process.memory_info()
    return mem_info.rss / (1024 ** 2)  # Return memory usage in MB


def clear_memory():
    st.session_state.clear()
    gc.collect()


# Function to read XML content from a file
def read_xml_file(filepath):
    """Read XML content from a file."""
    with open(filepath, 'r', encoding='utf-8') as file:
        return file.read()


# Function to recenter each bounding box and resize it to the canonical
# dimensions defined for its class, keeping the detected center
def modif_box_pos(pred, size):
    for i, (x1, y1, x2, y2) in enumerate(pred['boxes']):
        center = [(x1 + x2) / 2, (y1 + y2) / 2]
        label = class_dict[pred['labels'][i]]
        if label in size:
            pred['boxes'][i] = [center[0] - size[label][0] / 2, center[1] - size[label][1] / 2,
                                center[0] + size[label][0] / 2, center[1] + size[label][1] / 2]
    return pred
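
# Worked example (sketch): assuming a size mapping like {'task': (150, 120)},
# a detected 'task' box (100, 100, 220, 180) has center (160, 140) and is
# rewritten to (85, 80, 235, 200): same center, canonical 150x120 dimensions.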


# Function to create a BPMN XML file from prediction results
def create_XML(full_pred, text_mapping, scale):
    namespaces = {
        'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
        'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
        'di': 'http://www.omg.org/spec/DD/20100524/DI',
        'dc': 'http://www.omg.org/spec/DD/20100524/DC',
        'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
    }

    # Canonical (width, height) of each element type in the exported diagram
    size_elements = {
        'start': (54, 54),
        'task': (150, 120),
        'message': (54, 54),
        'messageEvent': (54, 54),
        'end': (54, 54),
        'exclusiveGateway': (75, 75),
        'event': (54, 54),
        'parallelGateway': (75, 75),
        'sequenceFlow': (225, 15),
        'pool': (375, 150),
        'lane': (300, 150),
        'dataObject': (60, 90),
        'dataStore': (90, 90),
        'subProcess': (180, 135),
        'eventBasedGateway': (75, 75),
        'timerEvent': (60, 60),
    }
    definitions = ET.Element('bpmn:definitions', {
        'xmlns:xsi': namespaces['xsi'],
        'xmlns:bpmn': namespaces['bpmn'],
        'xmlns:bpmndi': namespaces['bpmndi'],
        'xmlns:di': namespaces['di'],
        'xmlns:dc': namespaces['dc'],
        'targetNamespace': "http://example.bpmn.com",
        'id': "simpleExample"
    })

    # Create the BPMN collaboration element
    collaboration = ET.SubElement(definitions, 'bpmn:collaboration', id='collaboration_1')

    # Create one BPMN process element per detected pool
    process = []
    for idx, pool_index in enumerate(full_pred['pool_dict']):
        process.append(ET.SubElement(
            definitions, 'bpmn:process', id=f'process_{idx + 1}', isExecutable='false',
            name=text_mapping[full_pred['BPMN_id'][pool_index]]
        ))

    bpmndi = ET.SubElement(definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
    bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')
    full_pred['boxes'] = rescale(scale, full_pred['boxes'])

    # Add a participant shape (the pool) to the diagram for each process
    for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
        pool_id = f'participant_{idx + 1}'
        pool = ET.SubElement(
            collaboration, 'bpmn:participant', id=pool_id, processRef=f'process_{idx + 1}',
            name=text_mapping[full_pred['BPMN_id'][pool_index]]
        )

        # Calculate the bounding box for the pool
        if len(keep_elements) == 0:
            # Empty pool: fall back to its own detected box
            min_x, min_y, max_x, max_y = full_pred['boxes'][pool_index]
            pool_width = max_x - min_x
            pool_height = max_y - min_y
        else:
            min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred, keep_elements, size_elements)
            pool_width = max_x - min_x + 100   # Add padding around the contained elements
            pool_height = max_y - min_y + 100  # Add padding around the contained elements

        add_diagram_elements(bpmnplane, pool_id, min_x - 50, min_y - 50, pool_width, pool_height)
    # Create the BPMN elements for each pool
    for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
        create_bpmn_object(process[idx], bpmnplane, text_mapping, definitions, size_elements, full_pred, keep_elements)

    # Create message flow elements (flows between pools)
    message_flows = [i for i, label in enumerate(full_pred['labels']) if class_dict[label] == 'messageFlow']
    for idx in message_flows:
        create_flow_element(bpmnplane, text_mapping, idx, size_elements, full_pred, collaboration, message=True)

    # Create sequence flow elements (flows within a pool)
    for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
        for i in keep_elements:
            if full_pred['labels'][i] == list(class_dict.values()).index('sequenceFlow'):
                create_flow_element(bpmnplane, text_mapping, i, size_elements, full_pred, process[idx], message=False)
    # Serialize the tree and pretty-print the XML string
    rough_string = ET.tostring(definitions, 'utf-8')
    reparsed = minidom.parseString(rough_string)
    pretty_xml_as_string = reparsed.toprettyxml(indent="  ")

    # Rescale the boxes back so later display steps see the original coordinates
    full_pred['boxes'] = rescale(1 / scale, full_pred['boxes'])

    return pretty_xml_as_string
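
# Example (sketch): turning a finished prediction into a downloadable .bpmn
# file. `st.session_state.prediction` and `st.session_state.text_mapping` are
# assumed to have been filled by perform_inference(), as in main() below;
# 'diagram.bpmn' is a hypothetical output path.
# xml_string = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, 1.0)
# Path('diagram.bpmn').write_text(xml_string, encoding='utf-8')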


# Function to load the models only once, using session state to keep track of them
def load_models():
    with st.spinner('Loading model...'):
        model_object = get_faster_rcnn_model(len(object_dict))
        model_arrow = get_arrow_model(len(arrow_dict), 2)

        url_arrow = 'https://drive.google.com/uc?id=1xwfvo7BgDWz-1jAiJC1DCF0Wp8YlFNWt'
        url_object = 'https://drive.google.com/uc?id=1GiM8xOXG6M6r8J9HTOeMJz9NKu7iumZi'

        # Define paths to save the model weights
        output_arrow = 'model_arrow.pth'
        output_object = 'model_object.pth'

        # Download the weights with gdown unless they are already cached locally
        if not Path(output_arrow).exists():
            gdown.download(url_arrow, output_arrow, quiet=False)
        else:
            print('Arrow model loaded from local file')
        if not Path(output_object).exists():
            gdown.download(url_object, output_object, quiet=False)
        else:
            print('Object model loaded from local file')

        # Load the weights into the models
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
        model_object.load_state_dict(torch.load(output_object, map_location=device))

        st.session_state.model_loaded = True
        st.session_state.model_arrow = model_arrow
        st.session_state.model_object = model_object
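
# Usage note (sketch): main() below only calls load_models() when the models
# are missing from st.session_state, so a typical Streamlit rerun reuses the
# already-loaded weights instead of downloading them again:
# if 'model_object' not in st.session_state or 'model_arrow' not in st.session_state:
#     load_models()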


# Function to prepare the image for processing
def prepare_image(image, pad=True, new_size=(1333, 1333)):
    original_size = image.size

    # Calculate the scale that fits the new size while maintaining aspect ratio
    scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
    new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))

    # Resize the image to the new scaled size (torchvision expects (height, width))
    image = F.resize(image, (new_scaled_size[1], new_scaled_size[0]))

    if pad:
        enhancer = ImageEnhance.Brightness(image)
        image = enhancer.enhance(1.5)  # Adjust the brightness if necessary

        # Pad the resized image on the right and bottom to exactly the desired size
        padding = [0, 0, new_size[0] - new_scaled_size[0], new_size[1] - new_scaled_size[1]]
        image = F.pad(image, padding, fill=200, padding_mode='edge')

    return new_scaled_size, image
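
# Worked example (sketch): a 2000x1000 image targeted at (1333, 1333) gives
# scale = min(1333/2000, 1333/1000) = 0.6665, so the image is resized to
# 1333x666 and, with pad=True, padded by 667 pixels at the bottom to reach
# the full 1333x1333 input expected by the detector.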


# Function to display various options for image annotation
def display_options(image, score_threshold):
    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        write_class = st.toggle("Write Class", value=True)
        draw_keypoints = st.toggle("Draw Keypoints", value=True)
        draw_boxes = st.toggle("Draw Boxes", value=True)
    with col2:
        draw_text = st.toggle("Draw Text", value=False)
        write_text = st.toggle("Write Text", value=False)
        draw_links = st.toggle("Draw Links", value=False)
    with col3:
        write_score = st.toggle("Write Score", value=True)
        write_idx = st.toggle("Write Index", value=False)
    with col4:
        # Define the options for the class-filter dropdown menu
        dropdown_options = [list(class_dict.values())[i] for i in range(len(class_dict))]
        dropdown_options[0] = 'all'
        selected_option = st.selectbox("Show class", dropdown_options)

    # Draw the annotated image with the selected options
    annotated_image = draw_stream(
        np.array(image), prediction=st.session_state.prediction, text_predictions=st.session_state.text_pred,
        draw_keypoints=draw_keypoints, draw_boxes=draw_boxes, draw_links=draw_links, draw_twins=False, draw_grouped_text=draw_text,
        write_class=write_class, write_text=write_text, keypoints_correction=True, write_idx=write_idx, only_print=selected_option,
        score_threshold=score_threshold, write_score=write_score, resize=True, return_image=True, axis=True
    )

    # Display the original and annotated images side by side
    image_comparison(
        img1=annotated_image,
        img2=image,
        label1="Annotated Image",
        label2="Original Image",
        starting_position=99,
        width=1000,
    )


# Function to perform inference on the uploaded image using the loaded models
def perform_inference(model_object, model_arrow, image, score_threshold):
    _, uploaded_image = prepare_image(image, pad=False)
    img_tensor = F.to_tensor(prepare_image(image.convert('RGB'))[1])

    # Display the original image in a placeholder while the prediction runs
    image_placeholder = st.empty()
    image_placeholder.image(uploaded_image, caption='Original Image', width=1000)

    # Prediction
    _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=0.5)

    # Perform OCR on the uploaded image
    ocr_results = text_prediction(uploaded_image)

    # Filter the OCR results and map them onto the detected elements
    st.session_state.text_pred = filter_text(ocr_results, threshold=0.5)
    st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=0.5)

    # Remove the original image display
    image_placeholder.empty()

    # Force garbage collection
    gc.collect()


@st.cache_data
def get_image(uploaded_file):
    return Image.open(uploaded_file).convert('RGB')


def main():
    st.set_page_config(layout="wide")
    st.title("BPMN model recognition demo")

    # Display current memory usage
    memory_usage = get_memory_usage()
    print(f"Current memory usage: {memory_usage:.2f} MB")

    # Initialize the session state for storing pool bounding boxes
    if 'pool_bboxes' not in st.session_state:
        st.session_state.pool_bboxes = []

    # Load the models if they are not already in the session state
    if 'model_object' not in st.session_state or 'model_arrow' not in st.session_state:
        clear_memory()
        load_models()

    model_arrow = st.session_state.model_arrow
    model_object = st.session_state.model_object

    # Create the layout for the app
    col1, col2 = st.columns(2)
    with col1:
        # Create a file uploader for the user to upload an image
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    # Display the uploaded image once the user has uploaded one
    if uploaded_file is not None:
        original_image = get_image(uploaded_file)
        col1, col2 = st.columns(2)

        # Let the user crop the image and display the cropped result
        with col1:
            cropped_image = st_cropper(original_image, realtime_update=True, box_color='#0000FF', should_resize_image=True, default_coords=(30, original_image.size[0] - 30, 30, original_image.size[1] - 30))
        with col2:
            st.image(cropped_image, caption="Cropped Image", use_column_width=False, width=500)

        # Display the options for the user to set the score threshold and scale
        if cropped_image is not None:
            col1, col2, col3 = st.columns(3)
            with col1:
                score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
            with col2:
                st.session_state.scale = st.slider("Set scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)

            # Launch the prediction when the user clicks the button
            if st.button("Launch Prediction"):
                st.session_state.crop_image = cropped_image
                with st.spinner('Processing...'):
                    perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold)
                    st.session_state.prediction = modif_box_pos(st.session_state.prediction, object_dict)
                print('Detection completed!')

    # Once a prediction exists for the uploaded image, display the annotation
    # options and the generated BPMN diagram
    if 'prediction' in st.session_state and uploaded_file is not None:
        st.success('Detection completed!')
        display_options(st.session_state.crop_image, score_threshold)

        st.session_state.bpmn_xml = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, st.session_state.scale)
        display_bpmn_xml(st.session_state.bpmn_xml)

    # Force garbage collection after display
    gc.collect()


if __name__ == "__main__":
    print('Starting the app...')
    main()