# Hugging Face Spaces app (page-status text "Spaces: Sleeping" removed —
# it was scrape residue, not part of the program).
# Import necessary libraries | |
import gradio as gr # Gradio for building the web interface | |
from ultralytics import YOLO # YOLO model for object detection | |
from PIL import Image # PIL for image processing | |
import os # OS module for file operations | |
import cv2 # OpenCV for video processing | |
import tempfile # Temporary file creation | |
# Load the pre-trained YOLOv8 weights once at import time so the image and
# video handlers below share a single model instance.
model = YOLO('./best_yolo8_model/best.pt') # Relative path to the trained weights — must exist when the app starts
# Fixed display dimensions so the input/output widgets render at a
# consistent size in the Gradio UI (pixels).
image_height = 300 # Height of the image widgets
image_width = 400 # Width of the image widgets
video_height = 300 # Height of the video widgets
video_width = 400 # Width of the video widgets
# --- Image inference --------------------------------------------------------
def process_image(image):
    """Run YOLOv8 detection on a single uploaded image.

    Args:
        image (PIL.Image): Image uploaded through the Gradio UI.

    Returns:
        PIL.Image: Copy of the input with detection results drawn on it.
    """
    # Inference with a 0.5 confidence threshold; predict() accepts PIL images.
    detections = model.predict(source=image, conf=0.5)
    annotated_bgr = detections[0].plot()  # ndarray in OpenCV's BGR channel order
    # Swap channels to RGB so PIL/Gradio render the colors correctly.
    annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(annotated_rgb)
# --- Video inference --------------------------------------------------------
def process_video(video_path):
    """Run YOLOv8 detection on every frame of an uploaded video.

    Args:
        video_path (str): Path to the uploaded video file.

    Returns:
        str: Path to the annotated MP4 written into a fresh temp directory
             (the directory is left for Gradio/OS cleanup).
    """
    temp_dir = tempfile.mkdtemp()  # fresh dir so concurrent requests never collide
    output_video_path = os.path.join(temp_dir, "processed_video.mp4")
    cap = cv2.VideoCapture(video_path)
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            # Some containers report 0 FPS; a zero-FPS VideoWriter produces
            # an unplayable file, so fall back to a sane default.
            fps = 25.0
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # MP4 codec
        out = cv2.VideoWriter(output_video_path, fourcc, fps,
                              (frame_width, frame_height))
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:  # end of stream (or read error)
                    break
                results = model.predict(source=frame, conf=0.5)
                # plot() returns a BGR ndarray — exactly what VideoWriter
                # expects, so write it directly.  (The previous version
                # converted BGR->RGB and immediately back to BGR, a
                # per-frame no-op that only wasted CPU.)
                out.write(results[0].plot())
        finally:
            out.release()  # flush/close the writer even if inference raises
    finally:
        cap.release()  # always release the capture handle
    return output_video_path
# --- Gradio UI --------------------------------------------------------------
# Two-tab layout: one tab for still images, one for videos.  Each tab holds
# an input/output pair side by side plus a button wired to its handler.
with gr.Blocks() as app:
    gr.Markdown("## YOLOv8 Object Detection - Image & Video")  # app title
    gr.Markdown(
        "This app detects objects in images and videos using a YOLOv8s model. It detects DURIAN and RAMBUTAN fruits. Use the tabs below to process images or videos."
    )
    with gr.Tabs():
        with gr.TabItem("🖼️ Image Detection"):
            with gr.Row():
                with gr.Column():
                    img_in = gr.Image(
                        type="pil",
                        label="Input Image",
                        elem_id="image_input",
                        width=image_width,
                        height=image_height,  # fixed size keeps the layout stable
                    )
                with gr.Column():
                    img_out = gr.Image(
                        type="pil",
                        label="Output Image",
                        elem_id="image_output",
                        width=image_width,
                        height=image_height,
                    )
            img_btn = gr.Button("Detect Durian and Rambutan fruits in Image")
            # Run image inference when the button is pressed.
            img_btn.click(process_image, inputs=img_in, outputs=img_out)
        with gr.TabItem("🎥 Video Detection"):
            with gr.Row():
                with gr.Column():
                    vid_in = gr.Video(
                        label="Input Video",
                        elem_id="video_input",
                        width=video_width,
                        height=video_height,
                    )
                with gr.Column():
                    vid_out = gr.Video(
                        label="Output Video",
                        elem_id="video_output",
                        width=video_width,
                        height=video_height,
                    )
            vid_btn = gr.Button("Detect Durian and Rambutan fruits in Video")
            # Run video inference when the button is pressed.
            vid_btn.click(process_video, inputs=vid_in, outputs=vid_out)
# Start the web server.
app.launch()