isyourshotbroke / app.py
web3slinger's picture
first commit
20873ac
from PIL import Image
import cv2
import glob
import os
import streamlit as st
import torch
# Path to the custom YOLOv5 checkpoint used by the app. An earlier revision
# pointed at 'yolov5/runs/train/exp/weights/last.pt' instead.
_WEIGHTS_PATH = 'data/model/exp/weights/last.pt'

# Loaded at import time; detect_image() and infer_image() both call this
# module-level instance.
# NOTE(review): Streamlit re-executes the script on each interaction, so this
# load presumably repeats every rerun - consider caching it; confirm first.
model = torch.hub.load(
    'ultralytics/yolov5',
    'custom',
    path=_WEIGHTS_PATH,
    skip_validation=True,
)
def detect_image():
    """Run the model on one fixed image under ``data/images``.

    The detection summary is printed through ``results.print()``.

    Returns:
        Whatever ``render()`` returns for the detection results.
    """
    sample_path = os.path.join(
        'data',
        'images',
        'good_shooting_form.d4c0bb30-ee73-11ed-ae90-4685b43730ba.jpg',
    )
    detections = model(sample_path)
    detections.print()
    rendered = detections.render()
    return rendered
def image_input(data_src):
    """Let the user pick or upload an image and show the model's prediction.

    Args:
        data_src: Selected input source. ``'Sample data'`` offers the files
            found under ``data/sample_images`` through a slider; any other
            value shows an image upload widget instead.

    Returns:
        None. All output is written to the Streamlit page.
    """
    img_file = None
    if data_src == 'Sample data':
        # glob returns entries in arbitrary order; sorting keeps each slider
        # position mapped to the same file from one rerun to the next.
        img_paths = sorted(glob.glob('data/sample_images/*'))
        if not img_paths:
            # With no samples the slider would be asked for the range
            # min_value=1 .. max_value=0, which is not a valid range.
            st.warning("No sample images found in data/sample_images.")
            return
        img_slider = st.slider("Select a test image.", min_value=1, max_value=len(img_paths), step=1)
        # The slider is 1-based for the user; the list is 0-based.
        img_file = img_paths[img_slider - 1]
    else:
        img_bytes = st.file_uploader("Upload an image", type=['png', 'jpeg', 'jpg'])
        if img_bytes:
            # Make sure the target folder exists before saving into it.
            os.makedirs(os.path.join('data', 'uploaded_data'), exist_ok=True)
            # Keep the uploaded file's extension so it is saved in its own format.
            img_file = "data/uploaded_data/upload." + img_bytes.name.split('.')[-1]
            Image.open(img_bytes).save(img_file)

    if img_file:
        img = infer_image(img_file)
        st.image(img, caption="Model prediction")
# if img_file:
# col1, col2 = st.columns(2)
# with col1:
# st.image(img_file, caption="Selected Image")
# with col2:
# img = infer_image(img_file)
# st.image(img, caption="Model prediction")
def video_input(data_src):
    """Play a sample or uploaded video with model predictions drawn per frame.

    Args:
        data_src: Selected input source. ``'Sample data'`` uses
            ``data/sample_videos/demo.mp4``; any other value shows a video
            upload widget and uses the uploaded file once one is provided.

    Returns:
        None. Frames are rendered into a single Streamlit placeholder that is
        overwritten on every iteration.
    """
    vid_file = None
    if data_src == 'Sample data':
        vid_file = os.path.join('data', 'sample_videos', 'demo.mp4')
    else:
        vid_bytes = st.file_uploader("Upload a video", type=['mp4', 'mpv', 'avi'])
        if vid_bytes:
            # Make sure the target folder exists before writing into it.
            os.makedirs(os.path.join('data', 'uploaded_data'), exist_ok=True)
            vid_file = "data/uploaded_data/upload." + vid_bytes.name.split('.')[-1]
            with open(vid_file, 'wb') as out:
                out.write(vid_bytes.read())

    if vid_file is None:
        # Nothing uploaded yet: wait for the user instead of handing None
        # to cv2.VideoCapture.
        return

    cap = cv2.VideoCapture(vid_file)
    try:
        if not cap.isOpened():
            st.error(f"Could not open video: {vid_file}")
            return

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        st.markdown("---")
        output = st.empty()
        while True:
            ret, frame = cap.read()
            if not ret:
                # Also reached at the normal end of the video.
                st.write("Can't read frame. Exiting....")
                break
            frame = cv2.resize(frame, (width, height))
            # OpenCV decodes frames as BGR; convert to RGB before inference
            # and display.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            output_img = infer_image(frame)
            output.image(output_img)
            # NOTE(review): no OpenCV window is opened here, so this 'q' key
            # check presumably never fires - confirm before removing it.
            if cv2.waitKey(10) & 0xFF == ord('q'):
                break
    finally:
        # Release the capture even if inference or rendering raises.
        cap.release()
def infer_image(img):
    """Run the module-level model on ``img`` and return the annotated picture.

    Args:
        img: Input forwarded unchanged to ``model``. Callers in this file
            pass either an image file path or an RGB frame array.

    Returns:
        A ``PIL.Image.Image`` built from the first entry of the result's
        ``ims`` list after ``render()`` has been called.
    """
    detections = model(img)
    # render() is called for its effect on detections.ims; its return value
    # is not used (presumably it draws the boxes in place - see YOLOv5 docs).
    detections.render()
    return Image.fromarray(detections.ims[0])
def main():
    """Render the two selection widgets and dispatch to the chosen input flow."""
    # What kind of media to run the model on.
    input_option = st.radio("Select input type: ", ['Image', 'Video'])
    # Where that media comes from.
    data_src = st.radio("Select input source: ", ['Sample data', 'Upload your own data'])

    # Anything other than 'Image' is handled as video.
    handler = image_input if input_option == 'Image' else video_input
    handler(data_src)
# Script entry point: build the page only when this file is run directly.
if __name__ == "__main__":
    try:
        main()
    except SystemExit:
        # A SystemExit raised while running main() is deliberately ignored.
        pass