"""Gradio app for pangolin and axolotl detection with a YOLO (OpenVINO) model."""

import os
import tempfile

import cv2
import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image
from tqdm import tqdm
from ultralytics import YOLO


def load_model(repo_id):
    """Download the model snapshot from the Hugging Face Hub and load it."""
    download_dir = snapshot_download(repo_id)
    path = os.path.join(download_dir, "best_int8_openvino_model")
    detection_model = YOLO(path, task="detect")
    return detection_model


def predict_image(pilimg, conf_threshold, iou_threshold):
    """Run detection on a single image with user-defined thresholds."""
    try:
        result = detection_model.predict(pilimg, conf=conf_threshold, iou=iou_threshold)
        img_bgr = result[0].plot()  # Annotated image as a BGR numpy array
        out_pilimg = Image.fromarray(img_bgr[..., ::-1])  # BGR -> RGB PIL image
        return out_pilimg
    except Exception as e:
        # Raise a Gradio error so the UI shows a message instead of
        # receiving a string where an image is expected.
        raise gr.Error(f"Error processing image: {e}")


def predict_video(video_file, conf_threshold, iou_threshold, start_time, end_time):
    """Run detection on a video segment and return the path of the annotated copy."""
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise gr.Error("Unable to open the video file.")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Write the processed video to a temporary file so Gradio can serve it.
    # Close the handle immediately; cv2.VideoWriter opens the path itself.
    temp_video_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    output_path = temp_video_file.name
    temp_video_file.close()
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")

    start_frame = int(start_time * fps) if start_time else 0
    end_frame = int(end_time * fps) if end_time else total_frames  # 0 means full video
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

    out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))

    with tqdm(total=end_frame - start_frame, desc="Processing Video") as pbar:
        while cap.isOpened():
            current_frame = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
            if current_frame >= end_frame:
                break
            ret, frame = cap.read()
            if not ret:
                break
            resized_frame = cv2.resize(frame, (640, 640))  # Resize for inference
            result = detection_model.predict(resized_frame, conf=conf_threshold, iou=iou_threshold)
            output_frame = result[0].plot()  # Annotated frame (BGR)
            output_frame = cv2.resize(output_frame, (frame_width, frame_height))  # Restore original size
            out.write(output_frame)
            pbar.update(1)

    cap.release()
    out.release()
    return output_path


# Load the YOLO model once at startup.
REPO_ID = "Ganrong/107project"
detection_model = load_model(REPO_ID)

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Pangolin and Axolotl Detection")

    # Image processing tab
    with gr.Tab("Image Input"):
        img_input = gr.Image(type="pil", label="Upload an Image")
        conf_slider_img = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
        iou_slider_img = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold")
        img_output = gr.Image(type="pil", label="Processed Image")
        img_submit = gr.Button("Process Image")

        img_submit.click(
            predict_image,
            inputs=[img_input, conf_slider_img, iou_slider_img],
            outputs=img_output,
        )

    # Video processing tab
    with gr.Tab("Video Input"):
        video_input = gr.Video(label="Upload a Video")
        conf_slider_video = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
        iou_slider_video = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold")
        start_time = gr.Number(value=0, label="Start Time (seconds)")
        end_time = gr.Number(value=0, label="End Time (seconds, 0 for full video)")
        video_output = gr.Video(label="Processed Video")
        video_submit = gr.Button("Process Video")
        video_submit.click(
            predict_video,
            inputs=[video_input, conf_slider_video, iou_slider_video, start_time, end_time],
            outputs=video_output,
        )

demo.launch(share=True)