# Hugging Face Space: YOLO object-detection demo (image and video tabs).
import os
import tempfile

import cv2
import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image
from ultralytics import YOLO
def load_model(repo_id):
    """Download a model snapshot from the Hugging Face Hub and load it as a YOLO detector.

    Args:
        repo_id: Hub repository id (e.g. "user/model") whose snapshot contains
            a "best.pt" checkpoint at its root.

    Returns:
        A YOLO model instance configured for the 'detect' task.
    """
    download_dir = snapshot_download(repo_id)
    # snapshot_download caches the repo locally and returns its directory.
    weights_path = os.path.join(download_dir, "best.pt")
    return YOLO(weights_path, task='detect')
def predict(pilimg):
    """Run the detector on a PIL image and return the annotated image.

    Args:
        pilimg: Input PIL image, or None when nothing was uploaded.

    Returns:
        Annotated PIL image in RGB order, or None for empty input.
    """
    if pilimg is None:
        return None

    results = detection_model.predict(pilimg, conf=0.5, iou=0.6)
    annotated_bgr = results[0].plot()  # Ultralytics plot() yields a BGR numpy array
    # Reverse the channel axis so PIL receives RGB-ordered data.
    return Image.fromarray(annotated_bgr[..., ::-1])
def predict_video(video):
    """Run the detector on every frame of a video and return the annotated video's path.

    Args:
        video: Filesystem path to the uploaded video (Gradio passes uploads as
            paths), or None when nothing was uploaded.

    Returns:
        Path to a temporary .mp4 with detection overlays, or None for empty input.
    """
    if video is None:
        return None

    cap = cv2.VideoCapture(video)
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Some containers report 0 fps; fall back to a sane default so the
        # writer does not silently produce a broken file.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

        # NamedTemporaryFile avoids the filename race inherent in the
        # deprecated tempfile.mktemp; delete=False keeps the file for Gradio.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            temp_output_path = tmp.name
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_output_path, fourcc, fps, (frame_width, frame_height))
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                result = detection_model.predict(frame, conf=0.5, iou=0.6)
                # plot() returns a BGR array, which is exactly what VideoWriter
                # expects — the previous RGB flip produced color-swapped output.
                out.write(result[0].plot())
        finally:
            out.release()
    finally:
        cap.release()

    return temp_output_path
def enable_button(image_input, video_input):
    """Return a Gradio update that enables the button when either input is present.

    Args:
        image_input: Current image input value, or None.
        video_input: Current video input value, or None.

    Returns:
        A Gradio update dict toggling the button's `interactive` property.
    """
    # gr.Button.update() was removed in Gradio 4.x; gr.update() is the
    # version-portable way to emit a component update.
    return gr.update(interactive=image_input is not None or video_input is not None)
# Model repository on the Hugging Face Hub; the detector is loaded once at
# startup and shared by both prediction functions.
REPO_ID = "dexpyw/model"
detection_model = load_model(REPO_ID)

image_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Predicted Image"),
    live=False,
)

video_interface = gr.Interface(
    fn=predict_video,
    inputs=gr.Video(label="Upload Video"),
    outputs=gr.Video(label="Predicted Video"),
    live=False,
)

# Present both interfaces as tabs and expose a public share link.
gr.TabbedInterface([image_interface, video_interface], ["Image", "Video"]).launch(share=True)