Spaces:
Runtime error
Runtime error
| from PIL import Image | |
| import cv2 | |
| import glob | |
| import os | |
| import streamlit as st | |
| import torch | |
# Alternative checkpoint path (kept commented out for reference):
# model = torch.hub.load('ultralytics/yolov5', 'custom', path='yolov5/runs/train/exp/weights/last.pt', skip_validation=True)


def _load_model():
    """Load the custom YOLOv5 checkpoint via torch.hub.

    The model code comes from the 'ultralytics/yolov5' hub repo and the
    weights from data/model/exp/weights/last.pt.
    """
    return torch.hub.load('ultralytics/yolov5', 'custom', path='data/model/exp/weights/last.pt', skip_validation=True)


# Streamlit re-runs this script from the top on every widget interaction,
# which would otherwise reload the model each time. Cache the loaded model
# when the installed Streamlit provides st.cache_resource; on older versions
# fall back to loading on every run, as before.
_cache_resource = getattr(st, 'cache_resource', None)
if _cache_resource is not None:
    _load_model = _cache_resource(_load_model)

model = _load_model()
def detect_image(img=None):
    """Run the model on one image, print a summary and return the rendered output.

    Args:
        img: Input passed straight to ``model`` (e.g. an image file path).
            Defaults to the bundled example
            data/images/good_shooting_form.d4c0bb30-ee73-11ed-ae90-4685b43730ba.jpg,
            so calling ``detect_image()`` with no arguments behaves as before.

    Returns:
        Whatever ``results.render()`` returns. Presumably this is the list of
        annotated images (``infer_image`` reads ``result.ims[0]`` after
        ``render()``); confirm against the YOLOv5 Detections API.
    """
    if img is None:
        img = os.path.join('data', 'images', 'good_shooting_form.d4c0bb30-ee73-11ed-ae90-4685b43730ba.jpg')
    results = model(img)
    results.print()
    return results.render()
def image_input(data_src):
    """Select or upload an image, run the model on it and show the prediction.

    Args:
        data_src: 'Sample data' selects one of the files under
            data/sample_images/ (via a slider when there are several). Any
            other value shows an uploader and saves the upload to
            data/uploaded_data/upload.<ext> before inference.
    """
    img_file = None
    if data_src == 'Sample data':
        # glob() returns files in arbitrary order; sort so a given slider
        # position always maps to the same image across reruns.
        img_path = sorted(glob.glob('data/sample_images/*'))
        if not img_path:
            st.warning("No sample images found in data/sample_images.")
            return
        if len(img_path) == 1:
            # st.slider requires min_value < max_value, so skip it when
            # there is only one image to choose from.
            img_file = img_path[0]
        else:
            img_slider = st.slider("Select a test image.", min_value=1, max_value=len(img_path), step=1)
            img_file = img_path[img_slider - 1]
    else:
        img_bytes = st.file_uploader("Upload an image", type=['png', 'jpeg', 'jpg'])
        if img_bytes:
            img_file = "data/uploaded_data/upload." + img_bytes.name.split('.')[-1]
            # The upload directory may not exist on a fresh checkout.
            os.makedirs(os.path.dirname(img_file), exist_ok=True)
            Image.open(img_bytes).save(img_file)
    if img_file:
        img = infer_image(img_file)
        st.image(img, caption="Model prediction")
def video_input(data_src):
    """Select or upload a video and stream per-frame model predictions.

    Args:
        data_src: 'Sample data' uses data/sample_videos/demo.mp4. Any other
            value shows an uploader and writes the upload to
            data/uploaded_data/upload.<ext> before processing.
    """
    vid_file = None
    if data_src == 'Sample data':
        vid_file = os.path.join('data', 'sample_videos', 'demo.mp4')
    else:
        vid_bytes = st.file_uploader("Upload a video", type=['mp4', 'mpv', 'avi'])
        if vid_bytes:
            vid_file = "data/uploaded_data/upload." + vid_bytes.name.split('.')[-1]
            # The upload directory may not exist on a fresh checkout.
            os.makedirs(os.path.dirname(vid_file), exist_ok=True)
            with open(vid_file, 'wb') as out:
                out.write(vid_bytes.read())

    if not vid_file:
        # Nothing uploaded yet: wait for the user instead of trying to open
        # a capture on None.
        return

    cap = cv2.VideoCapture(vid_file)
    try:
        if not cap.isOpened():
            st.error(f"Could not open video: {vid_file}")
            return
        st.markdown("---")
        output = st.empty()
        while True:
            ret, frame = cap.read()
            if not ret:
                st.write("Can't read frame. Exiting....")
                break
            # OpenCV decodes frames as BGR; convert to RGB before inference.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Overwrite the same placeholder so the frames play in place.
            output.image(infer_image(frame))
    finally:
        # Release the capture even if inference raises mid-stream.
        cap.release()
def infer_image(img):
    """Run the detector on ``img`` and return the first rendered image as PIL.

    Args:
        img: Any input ``model`` accepts. In this file that is an image file
            path (from image_input) or an RGB frame array (from video_input).

    Returns:
        A PIL Image built from ``ims[0]`` of the detection result, read after
        ``render()`` has been called on it.
    """
    detections = model(img)
    # render() is called for its effect on detections.ims, which is read next.
    detections.render()
    first_rendered = detections.ims[0]
    return Image.fromarray(first_rendered)
def main():
    """App entry point: ask for media type and source, then dispatch."""
    # Kind of media to run detection on.
    input_option = st.radio("Select input type: ", ['Image', 'Video'])
    # Where that media comes from: bundled samples or a user upload.
    data_src = st.radio("Select input source: ", ['Sample data', 'Upload your own data'])
    handler = image_input if input_option == 'Image' else video_input
    handler(data_src)
if __name__ == "__main__":
    # Run the app. A SystemExit raised from main() is deliberately ignored
    # so the script simply ends instead of propagating the exit.
    try:
        main()
    except SystemExit:
        ...