# YOLOv8 Object Detection demo (Gradio app): detects DURIAN and RAMBUTAN
# fruits in uploaded images and videos using a custom-trained YOLOv8s model.
# --- Imports (stdlib first, then third-party) ---
import os        # Filesystem path handling for the output video
import tempfile  # Temporary directory for processed videos

import cv2                    # OpenCV for frame capture, color conversion, video writing
import gradio as gr           # Web UI framework
from PIL import Image         # ndarray -> PIL conversion for Gradio image display
from ultralytics import YOLO  # YOLO object-detection model

# Load the custom-trained YOLOv8 weights (path is relative to the app's working dir).
model = YOLO('./best_yolo8_model/best.pt')

# Fixed display dimensions so the input/output widgets render consistently.
image_height = 300  # Height of the image display
image_width = 400   # Width of the image display
video_height = 300  # Height of the video display
video_width = 400   # Width of the video display
| # Function to process an uploaded image | |
def process_image(image, conf=0.5):
    """
    Perform object detection on an uploaded image.

    Args:
        image (PIL.Image): Input image uploaded by the user.
        conf (float): Confidence threshold for detections (default 0.5,
            matching the app's original behavior).

    Returns:
        PIL.Image: Copy of the input annotated with detection results.
    """
    results = model.predict(source=image, conf=conf)  # Run inference
    result_img = results[0].plot()  # BGR ndarray with boxes/labels drawn
    # plot() yields BGR; convert to RGB so PIL/Gradio display colors correctly.
    return Image.fromarray(cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB))
| # Function to process an uploaded video | |
def process_video(video_path, conf=0.5):
    """
    Perform object detection on an uploaded video and return the processed video.

    Args:
        video_path (str): Path to the uploaded video file.
        conf (float): Confidence threshold for detections (default 0.5,
            matching the app's original behavior).

    Returns:
        str: Path to the processed (annotated) video file.
    """
    temp_dir = tempfile.mkdtemp()  # Temporary directory for the output video
    output_video_path = os.path.join(temp_dir, "processed_video.mp4")

    cap = cv2.VideoCapture(video_path)  # Open the input video
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Some containers report 0 FPS; fall back to 30 so the writer stays valid.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # MP4 codec
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

    try:
        # Process each frame of the video.
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:  # No more frames (or read failure): stop
                break
            results = model.predict(source=frame, conf=conf)
            # plot() returns a BGR frame, which is exactly what VideoWriter
            # expects — the former BGR->RGB->BGR round-trip was redundant work.
            out.write(results[0].plot())
    finally:
        # Always release capture/writer, even if inference raises mid-video.
        cap.release()
        out.release()

    return output_video_path
| # Create the Gradio interface with tabs for image and video detection | |
# Build the Gradio UI: one tab for image detection, one for video detection.
with gr.Blocks() as app:
    gr.Markdown("## YOLOv8 Object Detection - Image & Video")
    gr.Markdown(
        "This app detects objects in images and videos using a YOLOv8s model. It detects DURIAN and RAMBUTAN fruits. Use the tabs below to process images or videos."
    )

    with gr.Tabs():
        # --- Image tab: upload on the left, annotated result on the right ---
        with gr.TabItem("🖼️ Image Detection"):
            with gr.Row():
                with gr.Column():
                    img_in = gr.Image(
                        type="pil", label="Input Image", elem_id="image_input",
                        width=image_width, height=image_height,
                    )
                with gr.Column():
                    img_out = gr.Image(
                        type="pil", label="Output Image", elem_id="image_output",
                        width=image_width, height=image_height,
                    )
            detect_img_btn = gr.Button("Detect Durian and Rambutan fruits in Image")
            # Clicking runs inference on the uploaded image.
            detect_img_btn.click(process_image, inputs=img_in, outputs=img_out)

        # --- Video tab: same layout, video widgets instead of images ---
        with gr.TabItem("🎥 Video Detection"):
            with gr.Row():
                with gr.Column():
                    vid_in = gr.Video(
                        label="Input Video", elem_id="video_input",
                        width=video_width, height=video_height,
                    )
                with gr.Column():
                    vid_out = gr.Video(
                        label="Output Video", elem_id="video_output",
                        width=video_width, height=video_height,
                    )
            detect_vid_btn = gr.Button("Detect Durian and Rambutan fruits in Video")
            # Clicking runs frame-by-frame inference on the uploaded video.
            detect_vid_btn.click(process_video, inputs=vid_in, outputs=vid_out)

# Start the web server.
app.launch()