# Author: Daniel Cerda Escobar
# Last commit: 62438b8 "Upgrade plot" (file size 7.09 kB)
import pandas as pd
import plotly.express as px
import sahi.utils.cv
import sahi.utils.file
import streamlit as st
from PIL import Image
from sahi import AutoDetectionModel
from ultralyticsplus.hf_utils import download_from_hub

from utils import sahi_yolov8m_inference
# Remote locations of the images used by the app. The first three keys are
# the example diagrams offered by the radio selector; 'prediction_visual.png'
# is the example prediction shown before the user runs the model.
IMAGE_TO_URL = {
    'factory_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/factory-pid.png',
    'plant_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png',
    'processing_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/processing-pid.png',
    'prediction_visual.png': 'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png',
}
# Page-wide settings: must be the first Streamlit call that renders anything.
st.set_page_config(
    page_title="P&ID Object Detection",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.title('P&ID Object Detection')
st.subheader(' Identify valves and pumps with deep learning model ', divider='rainbow')
# Author's LinkedIn badge. It is raw HTML, hence unsafe_allow_html=True.
# The previous snippet ended with a stray '</p>' that had no opening tag
# (malformed markup); it has been removed.
st.markdown(
    """
<a href='https://cl.linkedin.com/in/daniel-cerda-escobar' target='_blank'><img src="https://img.icons8.com/fluency/48/000000/linkedin.png" height="30"></a>
""",
    unsafe_allow_html=True,
)
@st.cache_resource(show_spinner=False)
def get_model(postprocess_match_threshold):
    """Build the SAHI detection model around the P&ID YOLOv8 weights.

    The weights are obtained with ``download_from_hub('DanielCerda/pid_yolov8')``
    and wrapped in a CPU ``AutoDetectionModel``. Streamlit's resource cache
    keys on the argument, so one model instance is kept per distinct
    threshold value.

    Args:
        postprocess_match_threshold: Value forwarded to the model as its
            ``confidence_threshold``.

    Returns:
        The ``AutoDetectionModel`` instance.
    """
    weights_path = download_from_hub('DanielCerda/pid_yolov8')
    return AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=weights_path,
        confidence_threshold=postprocess_match_threshold,
        device="cpu",
    )
@st.cache_data(show_spinner=False)
def download_comparison_images():
    """Download the example diagram and the example prediction image to disk.

    Each file is saved under its ``IMAGE_TO_URL`` key as a local file name;
    the session-state setup later opens them by those names. The URLs are
    looked up in ``IMAGE_TO_URL`` rather than repeated here, so each remote
    location is defined in exactly one place. Because the function takes no
    arguments, Streamlit's data cache runs the body only once per cache
    lifetime.
    """
    for filename in ('plant_pid.png', 'prediction_visual.png'):
        sahi.utils.file.download_from_url(IMAGE_TO_URL[filename], filename)


download_comparison_images()
# Tables describing the bundled example prediction. They are displayed
# until the user runs the model themselves.
# coco_df: one row per detection (category, confidence score).
coco_df = pd.DataFrame(
    {
        'category': ['centrifugal-pump'] * 2 + ['gate-valve'] * 9,
        'score': [0.88, 0.85, 0.87, 0.87, 0.86, 0.86, 0.85, 0.84, 0.81, 0.81, 0.76],
    }
)
# output_df: detections per category, as an absolute count and as a
# percentage of all 11 detections above.
output_df = pd.DataFrame(
    {
        'category': ['ball-valve', 'butterfly-valve', 'centrifugal-pump', 'check-valve', 'gate-valve'],
        'count': [0, 0, 2, 0, 9],
        'percentage': [0, 0, 18.2, 0, 81.8],
    }
)
# Seed session state with the bundled example (diagram, example prediction
# image and its two tables) so the results panel has content before the first
# prediction. Each value is produced by a factory that is called only when
# its key is missing, so the images are opened and resized just once per
# session.
_initial_outputs = {
    "output_1": lambda: Image.open('plant_pid.png').resize((4960, 3508)),
    "output_2": lambda: Image.open('prediction_visual.png').resize((4960, 3508)),
    "output_3": lambda: coco_df,
    "output_4": lambda: output_df,
}
for _state_key, _make_default in _initial_outputs.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Short usage guide in a collapsible panel in the left column.
# The emoji below were previously mojibake (UTF-8 bytes decoded with a
# Windows code page, e.g. 'πŸš€' instead of the rocket emoji); restored here.
col1, col2, col3 = st.columns(3, gap='medium')
with col1:
    with st.expander('How to use it'):
        st.markdown(
            '''
            1) Upload or select any example diagram 👆🏻
            2) Set model parameters 📈
            3) Press to perform inference 🚀
            4) Visualize model predictions 🔎
            '''
        )

# Vertical spacer.
st.write('##')
# Three-column row: input selection | preview | model parameters.
col1, col2, col3 = st.columns(3, gap='large')
with col1:
    st.markdown('##### Set Input Image')

    # Source 1: a diagram uploaded by the user (None until a file is given).
    image_file = st.file_uploader(
        'Upload your P&ID',
        type=['jpg', 'jpeg', 'png'],
    )

    # Source 2: one of the bundled examples, labelled A/B/C in the UI.
    def radio_func(option):
        """Return the short display label for an example file name."""
        labels = {
            'factory_pid.png': 'A',
            'plant_pid.png': 'B',
            'processing_pid.png': 'C',
        }
        return labels[option]

    radio = st.radio(
        'Select from the following examples',
        options=['factory_pid.png', 'plant_pid.png', 'processing_pid.png'],
        format_func=radio_func,
    )
with col2:
    # Resolve the image that will be sent to the model: an uploaded file
    # takes precedence over the selected example.
    if image_file is not None:
        # Uploads are normalised to 3-channel RGB. PNG uploads may be RGBA,
        # palette or greyscale, and passing those modes on unchanged makes the
        # detector's input depend on the file format. The example branch goes
        # through sahi's reader, which is expected to return RGB for URL
        # inputs (NOTE(review): verify against the installed sahi version).
        image = Image.open(image_file).convert('RGB')
    else:
        image = sahi.utils.cv.read_image_as_pil(IMAGE_TO_URL[radio])
    st.markdown('##### Preview')
    with st.container(border=True):
        st.image(image, use_column_width=True)
with col3:
    # SAHI slicing and detection parameters chosen by the user.
    st.markdown('##### Set model parameters')

    # Number of tiles per image. Kept as strings; the prediction step converts
    # the choice to a number when it computes the tile size.
    slice_number = st.select_slider(
        'Slices per Image',
        options=['1', '4', '16', '64'],
        value='4',
    )

    # Fraction of overlap between neighbouring tiles (used for both axes).
    overlap_ratio = st.slider(
        label='Slicing Overlap Ratio',
        min_value=0.0,
        max_value=0.5,
        value=0.1,
        step=0.1,
    )

    # Passed to get_model, which uses it as the model's confidence_threshold.
    postprocess_match_threshold = st.slider(
        label='Confidence Threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.85,
        step=0.05,
    )
# Vertical spacer, then a centred prediction button.
st.write('##')

# The button label previously contained mojibake ('πŸš€'); restored to the
# rocket emoji.
col1, col2, col3 = st.columns([4, 1, 4])
with col2:
    submit = st.button("🚀 Perform Prediction")

if submit:
    # get_model is resource-cached per threshold, so weights are only
    # fetched the first time a given threshold is used.
    with st.spinner(text="Downloading model weights ... "):
        detection_model = get_model(postprocess_match_threshold)

    # Square tiles sized so that sqrt(N) of them span the 4960 px
    # image_size, i.e. 4960 / sqrt(N) for N slices per image (the same value
    # the original hard-coded expression produced).
    image_size = 4960
    slice_size = int(image_size / (float(slice_number) ** 0.5))

    with st.spinner(text="Performing prediction ... "):
        output_visual, coco_df, output_df = sahi_yolov8m_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
        )

    # Store the results so the results panel shows them on subsequent reruns.
    st.session_state["output_1"] = image
    st.session_state["output_2"] = output_visual
    st.session_state["output_3"] = coco_df
    st.session_state["output_4"] = output_df
# Vertical spacer before the results panel.
st.write('##')

# Results panel: everything is read from session state, so it shows either
# the bundled example or the latest prediction.
col1, col2, col3 = st.columns([1, 5, 1], gap='small')
with col2:
    st.markdown("#### Object Detection Result")
    results = st.session_state
    with st.container(border=True):
        original_tab, prediction_tab, data_tab, insights_tab = st.tabs(
            ['Original Image', 'Inference Prediction', 'Data', 'Insights']
        )

        # Input image and the annotated model output.
        with original_tab:
            st.image(results["output_1"])
        with prediction_tab:
            st.image(results["output_2"])

        # Per-detection table (category and confidence score).
        with data_tab:
            col1, col2, col3 = st.columns([1, 2, 1])
            with col2:
                st.dataframe(
                    results["output_3"],
                    column_config={
                        'category': 'Predicted Category',
                        'score': 'Confidence',
                    },
                    use_container_width=True,
                    hide_index=True,
                )

        # Bar chart of the number of detections per category.
        with insights_tab:
            col1, col2, col3 = st.columns([1, 5, 1])
            with col2:
                fig = px.bar(results["output_4"], x='category', y='count', color='category')
                fig.update_layout(
                    xaxis_title=None,
                    yaxis_title=None,
                    showlegend=False,
                    # Unit tick spacing starting at 0: counts are whole numbers.
                    yaxis={'tick0': 0, 'dtick': 1},
                    bargap=0.5,
                )
                st.plotly_chart(fig, use_container_width=True, theme='streamlit')