# Page header from the hosting site: demos / app.py by ch-tseng, commit 3acfc50 ('update'), 4.76 kB, marked "No virus".
# -*- coding: utf-8 -*-
# Runtime switches and paths:
#   HuggingFace        -- True: model settings come from st.secrets;
#                         False: they come from environment variables.
#   update_model_id    -- NOTE(review): not referenced in the code shown.
#   base_download_path -- folder that model files are downloaded into.
HuggingFace, update_model_id, base_download_path = True, None, "models"
import os
import subprocess
from pathlib import Path

from PIL import Image
import streamlit as st

import config
from utils import load_model, infer_uploaded_image, infer_uploaded_video, infer_uploaded_webcam
# setting page layout
st.set_page_config(
page_title="YOLO.dog",
page_icon="🤖",
layout="wide",
initial_sidebar_state="expanded"
)
# Legacy static model catalogue: model key -> [file extension, display name].
# The live catalogue is read from secrets/env further down; this mapping is
# only referenced by commented-out code there.
mnames = {
    key: ['.pt', display_name]
    for key, display_name in (
        ('cat_body_parts', 'Cat_Body_Parts'),
        ('road_defects', 'Road_Defects'),
        ('13_animals', '13-Kinds-Animals'),
        ('digger', 'Digger'),
        ('fire_smoke', 'Fire_Smoke'),
        ('gun_knife', 'Gun-Knife'),
        ('Type_of_car', 'Car_Types'),
        ('working_security', 'Working-Safety'),
        ('crowded_human', 'Crowded-Human'),
        ('face_mask_eyeballs', 'Face_Mask_Eyeballs'),
        ('forklift', 'Forklift'),
        ('Human_Palm', 'Human_Palm'),
        ('vehicles_plates', 'Vehicles_Plates'),
        ('like_dislike', 'Like_Dislike'),
        ('yolo_gender', 'Human_Gender'),
    )
}
# Model to preselect: taken from the URL's ?model=<display name> parameter
# (first value wins), otherwise the crowded-human model.
query_params = st.experimental_get_query_params()
qmodel = query_params['model'][0] if 'model' in query_params else 'Crowded-Human'
#------------------------
# Make sure the model download folder exists. exist_ok=True makes this a
# no-op when the folder is already there (including when it appears between
# checks). Any other OS failure is reported but not fatal, keeping the
# original best-effort behaviour.
try:
    os.makedirs(base_download_path, exist_ok=True)
except OSError as err:
    print("cannot create model folder", base_download_path, err)
def _read_setting(key):
    """Return configuration value *key* from the active settings source.

    Environment variables are used when HuggingFace is False (os.getenv returns
    None for a missing key). Otherwise st.secrets is used, which raises for a
    missing key.
    """
    if HuggingFace is False:
        return os.getenv(key)
    return st.secrets[key]


def _download_model(url, dest):
    """Download *url* to *dest* with wget, deleting any partial file on failure.

    ``wget -O dest`` creates *dest* even when the download fails. Left behind,
    that file would make the existence check below skip the model forever, so
    it is removed. Failures are logged, not raised, so the other models still
    load. The command is passed as an argument list (no shell), so characters
    in the URL cannot be interpreted by a shell.
    """
    print('wget {} '.format(url))
    try:
        subprocess.run(['wget', '-O', dest, '--content-disposition', url], check=True)
    except (OSError, subprocess.CalledProcessError) as err:
        print('download failed', url, err)
        if os.path.exists(dest):
            os.remove(dest)


# Build the model catalogue. Model i is described by the settings
# m{i}_name, m{i}_type (file extension), m{i}_desc (display name) and
# m{i}_url (download URL).
model_count = int(_read_setting("model_count"))
model_info = {}   # display name -> local model file path
models_list = []  # display names, in configuration order
for i in range(model_count):
    model_name = _read_setting("m{}_name".format(i))
    model_extname = _read_setting("m{}_type".format(i))
    model_desc = _read_setting("m{}_desc".format(i))
    model_url = _read_setting("m{}_url".format(i))
    path_model = os.path.join(base_download_path, model_name + model_extname)
    print('path_model', path_model)
    model_info[model_desc] = path_model
    models_list.append(model_desc)
    # Fetch missing models. Zero-byte files are also re-fetched, since they
    # are left behind by failed downloads from earlier runs.
    if not os.path.exists(path_model) or os.path.getsize(path_model) == 0:
        _download_model(model_url, path_model)

if not models_list:
    # Nothing to select from; models_list[0] below would raise IndexError.
    st.error("No models are configured (model_count is 0).")
    st.stop()
# Fall back to the first configured model if the URL named an unknown one.
if qmodel not in models_list:
    qmodel = models_list[0]
# Sidebar: the model picker, with the URL-requested model (or the first
# configured one) preselected.
st.sidebar.header("Model Config")
task_type = "Detection"  # NOTE(review): not read in the code shown
preselected = models_list.index(qmodel)
model_type = st.sidebar.selectbox("Models list", models_list, index=preselected)
# Show the page heading for the chosen model and load its weights.
if not model_type:
    # Without a selection there is no `model` for the inference calls below.
    # Stop this script run instead of failing later with a NameError.
    st.error("Please Select Model in Sidebar")
    st.stop()
st.header('{} Model Trial'.format(model_type), divider='rainbow')
st.subheader('Use your :blue[photo/video] :sunglasses:')
model_path = model_info[model_type]
print('model_path', model_path)
model = load_model(model_path)
# Input source. The source picker is currently hidden and fixed to "Image".
# The Video/Webcam branches remain so the picker can be restored as a sidebar
# selectbox over config.SOURCES_LIST.
# NOTE(review): assumes config.SOURCES_LIST[0] == "Image"; otherwise the
# else branch below fires.
source_selectbox = "Image"
# The slider works in percent (5-100, default 25) and is converted to a 0-1 fraction.
confidence = float(st.sidebar.slider(
    "Adjust Model Confidence", 5, 100, 25)) / 100
# Rendering toggles forwarded to the image inference helper.
drawBox = st.sidebar.checkbox("Draw box", value=True)
drawLabel = st.sidebar.checkbox("Display label", value=True)
drawScore = st.sidebar.checkbox("Display confidence", value=True)
source_img = None  # NOTE(review): not read in the code shown
if source_selectbox == config.SOURCES_LIST[0]:  # Image
    infer_uploaded_image(confidence, drawBox, drawLabel, drawScore, model)
elif source_selectbox == config.SOURCES_LIST[1]:  # Video
    infer_uploaded_video(confidence, model)
elif source_selectbox == config.SOURCES_LIST[2]:  # Webcam
    infer_uploaded_webcam(confidence, model)
else:
    # All three known sources are handled above, so report the offending
    # value rather than a (previously stale) list of supported sources.
    st.error("Unsupported source: {}".format(source_selectbox))