Spaces:
Sleeping
Sleeping
import pandas as pd
import plotly.express as px
import sahi.utils.cv
import sahi.utils.file
import streamlit as st
from PIL import Image
from sahi import AutoDetectionModel
from ultralyticsplus.hf_utils import download_from_hub

from utils import sahi_yolov8m_inference
# Hosting location of every image the app works with. The first three keys are
# the selectable example diagrams; 'prediction_visual.png' is the default
# result image shown before any prediction has been run.
IMAGE_TO_URL = {
    'factory_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/factory-pid.png',
    'plant_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png',
    'processing_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/processing-pid.png',
    'prediction_visual.png': 'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png',
}
# Page configuration is issued before any element is rendered.
st.set_page_config(
    page_title="P&ID Object Detection",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.title('P&ID Object Detection')
st.subheader(' Identify valves and pumps with deep learning model ', divider='rainbow')

# Author badge linking to LinkedIn. The paragraph tags are balanced, and
# rel="noopener noreferrer" prevents the new tab from reaching this page
# through window.opener.
st.markdown(
    """
<p>
<a href='https://cl.linkedin.com/in/daniel-cerda-escobar' target='_blank' rel='noopener noreferrer'><img src="https://img.icons8.com/fluency/48/000000/linkedin.png" alt="LinkedIn" height="30"></a>
</p>
""",
    unsafe_allow_html=True,
)
@st.cache_data(show_spinner=False)
def _download_model_weights():
    """Return the local path of the YOLOv8 P&ID checkpoint from the Hugging Face Hub.

    Cached so repeated predictions reuse the resolved path instead of calling
    ``download_from_hub`` again on every button press. Exceptions are not
    cached, so a failed download is retried on the next call.
    """
    return download_from_hub('DanielCerda/pid_yolov8')


def get_model(postprocess_match_threshold):
    """Build a YOLOv8 detection model, running on CPU, for SAHI sliced inference.

    Args:
        postprocess_match_threshold: value forwarded to the model as its
            ``confidence_threshold`` (the UI labels this slider
            "Confidence Threshold").

    Returns:
        The detection model created by ``AutoDetectionModel.from_pretrained``.
    """
    # Only the weights path is cached. The model object is rebuilt per call so
    # concurrent sessions never share one (possibly stateful) model instance.
    detection_model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=_download_model_weights(),
        confidence_threshold=postprocess_match_threshold,
        device="cpu",
    )
    return detection_model
def download_comparison_images():
    """Download the images used to seed the result panels on first load.

    Fetches 'plant_pid.png' (the default original image) and
    'prediction_visual.png' (the default prediction visual) into the working
    directory under those same file names. URLs are looked up in
    IMAGE_TO_URL so each address is defined in exactly one place.
    """
    for filename in ('plant_pid.png', 'prediction_visual.png'):
        sahi.utils.file.download_from_url(IMAGE_TO_URL[filename], filename)


download_comparison_images()
# Placeholder results displayed until the user runs a prediction; they seed the
# "Data" and "Insights" tabs through session state. Presumably these are the
# detections behind prediction_visual.png -- confirm if that image changes.
_demo_categories = ['centrifugal-pump'] * 2 + ['gate-valve'] * 9
_demo_scores = [0.88, 0.85, 0.87, 0.87, 0.86, 0.86, 0.85, 0.84, 0.81, 0.81, 0.76]
coco_df = pd.DataFrame({'category': _demo_categories, 'score': _demo_scores})

# Per-category summary of the placeholder detections above.
output_df = pd.DataFrame(
    {
        'category': ['ball-valve', 'butterfly-valve', 'centrifugal-pump', 'check-valve', 'gate-valve'],
        'count': [0, 0, 2, 0, 9],
        'percentage': [0, 0, 18.2, 0, 81.8],
    }
)
def _load_display_image(path):
    """Return the image at *path* resized to 4960x3508 pixels.

    The file is opened in a ``with`` block so its handle is released
    immediately. ``resize`` returns an independent image, so closing the
    source loses nothing.
    """
    with Image.open(path) as img:
        return img.resize((4960, 3508))


# Seed the result panels on first load so the page is never empty.
# A prediction run later overwrites these same keys.
if "output_1" not in st.session_state:
    st.session_state["output_1"] = _load_display_image('plant_pid.png')
if "output_2" not in st.session_state:
    st.session_state["output_2"] = _load_display_image('prediction_visual.png')
if "output_3" not in st.session_state:
    st.session_state["output_3"] = coco_df
if "output_4" not in st.session_state:
    st.session_state["output_4"] = output_df
# First row: a short usage guide in the left column (the other two columns
# are intentionally left empty to keep the guide narrow).
col1, col2, col3 = st.columns(3, gap='medium')
with col1:
    with st.expander('How to use it'):
        st.markdown(
            '''
            1) Upload or select any example diagram
            2) Set model parameters
            3) Press to perform inference
            4) Visualize model predictions
            '''
        )
st.write('##')
# Second row: input selection | preview | model parameters.
col1, col2, col3 = st.columns(3, gap='large')


def radio_func(option):
    """Map an example file name to the short label displayed on the radio button."""
    labels = {
        'factory_pid.png': 'A',
        'plant_pid.png': 'B',
        'processing_pid.png': 'C',
    }
    return labels[option]


with col1:
    st.markdown('##### Set Input Image')
    # Option 1: the user uploads their own diagram.
    image_file = st.file_uploader('Upload your P&ID', type=['jpg', 'jpeg', 'png'])
    # Option 2: one of the bundled examples, used whenever nothing is uploaded.
    radio = st.radio(
        'Select from the following examples',
        options=['factory_pid.png', 'plant_pid.png', 'processing_pid.png'],
        format_func=radio_func,
    )
@st.cache_data(show_spinner=False)
def _fetch_example_image(name):
    """Download the example diagram *name* (a key of IMAGE_TO_URL) as a PIL image.

    Cached so the remote file is fetched once instead of on every rerun (every
    widget interaction reruns the script). st.cache_data hands each caller its
    own copy, so downstream code may modify the image safely.
    """
    return sahi.utils.cv.read_image_as_pil(IMAGE_TO_URL[name])


with col2:
    # An uploaded file takes precedence over the selected example.
    if image_file is not None:
        # Normalize to 3-channel RGB: uploaded PNGs may carry an alpha channel
        # or a palette, which a 3-channel detector cannot consume directly.
        image = Image.open(image_file).convert('RGB')
    else:
        image = _fetch_example_image(radio)
    st.markdown('##### Preview')
    with st.container(border=True):
        st.image(image, use_column_width=True)
with col3:
    # SAHI slicing and detection settings.
    st.markdown('##### Set model parameters')
    # Options are the strings '1', '4', '16', '64' (powers of four); the
    # prediction step converts the chosen value with float().
    slice_number = st.select_slider(
        'Slices per Image',
        options=[str(4 ** exponent) for exponent in range(4)],
        value='4',
    )
    # Fraction of each tile that overlaps its neighbours (same for height and width).
    overlap_ratio = st.slider(
        label='Slicing Overlap Ratio',
        min_value=0.0,
        max_value=0.5,
        value=0.1,
        step=0.1,
    )
    # Passed to get_model and forwarded to the model as confidence_threshold.
    postprocess_match_threshold = st.slider(
        label='Confidence Threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.85,
        step=0.05,
    )
st.write('##')
col1, col2, col3 = st.columns([4, 1, 4])
with col2:
    submit = st.button("Perform Prediction")

if submit:
    # image_size is forwarded to sahi_yolov8m_inference (presumably the working
    # resolution -- confirm in utils). Tiles are square, with an edge of
    # image_size / sqrt(slice_number); slice_number is '1', '4', '16' or '64'.
    image_size = 4960
    slice_size = int(image_size / (float(slice_number) ** 0.5))

    with st.spinner(text="Downloading model weights ... "):
        detection_model = get_model(postprocess_match_threshold)

    with st.spinner(text="Performing prediction ... "):
        output_visual, coco_df, output_df = sahi_yolov8m_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
        )

    # Keep the results in session state so the result tabs still show them
    # after later reruns triggered by other widget interactions.
    st.session_state["output_1"] = image
    st.session_state["output_2"] = output_visual
    st.session_state["output_3"] = coco_df
    st.session_state["output_4"] = output_df
# Results area. Every tab reads from session state, which holds either the
# placeholder example data or the output of the latest prediction.
st.write('##')
col1, col2, col3 = st.columns([1, 5, 1], gap='small')
with col2:
    st.markdown("#### Object Detection Result")
    with st.container(border=True):
        tab1, tab2, tab3, tab4 = st.tabs(['Original Image', 'Inference Prediction', 'Data', 'Insights'])
        with tab1:
            st.image(st.session_state["output_1"])
        with tab2:
            st.image(st.session_state["output_2"])
        with tab3:
            # Centre the detections table in a narrower middle column.
            col1, col2, col3 = st.columns([1, 2, 1])
            with col2:
                st.dataframe(
                    st.session_state["output_3"],
                    column_config={
                        'category': 'Predicted Category',
                        'score': 'Confidence',
                    },
                    use_container_width=True,
                    hide_index=True,
                )
        with tab4:
            col1, col2, col3 = st.columns([1, 5, 1])
            with col2:
                # Bar chart of detections per category; dtick=1 keeps the
                # count axis on whole numbers.
                chart_data = st.session_state["output_4"]
                fig = px.bar(chart_data, x='category', y='count', color='category')
                fig.update_layout(
                    xaxis_title=None,
                    yaxis_title=None,
                    showlegend=False,
                    yaxis=dict(tick0=0, dtick=1),
                    bargap=0.5,
                )
                st.plotly_chart(fig, use_container_width=True, theme='streamlit')