# 7123558N / app.py
# Author: b4one
# Commit: 1a8f1da ("add commented code")
# Import necessary libraries
import gradio as gr # Gradio for building the web interface
from ultralytics import YOLO # YOLO model for object detection
from PIL import Image # PIL for image processing
import os # OS module for file operations
import cv2 # OpenCV for video processing
import tempfile # Temporary file creation
# The detector is loaded once, at import time, and reused by every request.
# The weights path is relative, so it is resolved against the working
# directory the app is launched from.
model = YOLO('./best_yolo8_model/best.pt')

# Sizes handed to the Gradio components so the input and output panes of
# each tab are rendered with identical dimensions.
image_height, image_width = 300, 400  # still-image panes
video_height, video_width = 300, 400  # video panes
# Function to process an uploaded image
def process_image(image, conf=0.5):
    """
    Run fruit detection on a single uploaded image.

    Args:
        image (PIL.Image.Image | None): Image from the Gradio input component,
            or None when "Detect" is clicked before anything was uploaded.
        conf (float): Minimum confidence for a detection to be kept and drawn.
            Defaults to 0.5, the threshold the app has always used.

    Returns:
        PIL.Image.Image | None: The image with detection results drawn on it,
        or None when no image was provided (leaves the output pane empty).
    """
    # Guard the empty case explicitly: Ultralytics treats source=None as
    # "use the bundled demo assets", which would show detections on
    # unrelated images instead of the user's upload.
    if image is None:
        return None

    results = model.predict(source=image, conf=conf)
    # Results.plot() returns a BGR array (OpenCV channel order); PIL expects
    # RGB, so swap channels before wrapping it for display.
    annotated_bgr = results[0].plot()
    return Image.fromarray(cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB))
# Function to process an uploaded video
def process_video(video_path, conf=0.5):
    """
    Run fruit detection on every frame of an uploaded video.

    Args:
        video_path (str | None): Path to the uploaded video file, or None/""
            when "Detect" is clicked before anything was uploaded.
        conf (float): Minimum confidence for a detection to be kept and drawn.
            Defaults to 0.5, the threshold the app has always used.

    Returns:
        str | None: Path to an MP4 file (in a fresh temporary directory) with
        detection results drawn on each frame, or None when no video was
        provided.

    Raises:
        gr.Error: If the uploaded file cannot be opened, or the output video
            writer cannot be created.
    """
    if not video_path:
        return None

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise gr.Error("Could not open the uploaded video.")

    output_video_path = os.path.join(tempfile.mkdtemp(), "processed_video.mp4")
    out = None
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report no frame rate (0 or NaN); fall back rather
        # than handing VideoWriter an invalid rate.
        # NOTE(review): 30.0 is an assumed default — adjust if needed.
        if not (fps > 0):
            fps = 30.0

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # MP4 container codec
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            raise gr.Error("Could not create the output video.")

        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:  # end of stream (or unreadable frame)
                break
            results = model.predict(source=frame, conf=conf)
            # Frames from cv2 are BGR and plot() keeps that channel order,
            # which is exactly what VideoWriter expects — no conversion needed.
            out.write(results[0].plot())
    finally:
        # Release handles even if inference fails part-way through.
        cap.release()
        if out is not None:
            out.release()

    return output_video_path
# ---------------------------------------------------------------------------
# Gradio UI: one tab for still images and one for videos. Each tab places the
# upload pane and the annotated result side by side, with a trigger button
# underneath.
# ---------------------------------------------------------------------------
with gr.Blocks() as app:
    # Page header and a short description of what the detector recognises.
    gr.Markdown("## YOLOv8 Object Detection - Image & Video")
    gr.Markdown(
        "This app detects objects in images and videos using a YOLOv8s model. It detects DURIAN and RAMBUTAN fruits. Use the tabs below to process images or videos."
    )

    with gr.Tabs():
        # Image tab: components use type="pil", matching process_image.
        with gr.TabItem("🖼️ Image Detection"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(
                        type="pil",
                        label="Input Image",
                        elem_id="image_input",
                        width=image_width,
                        height=image_height,
                    )
                with gr.Column():
                    image_output = gr.Image(
                        type="pil",
                        label="Output Image",
                        elem_id="image_output",
                        width=image_width,
                        height=image_height,
                    )
            image_submit = gr.Button("Detect Durian and Rambutan fruits in Image")
            image_submit.click(
                fn=process_image,
                inputs=[image_input],
                outputs=[image_output],
            )

        # Video tab: the uploaded file is handed to process_video, whose
        # returned path is shown in the output player.
        with gr.TabItem("🎥 Video Detection"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(
                        label="Input Video",
                        elem_id="video_input",
                        width=video_width,
                        height=video_height,
                    )
                with gr.Column():
                    video_output = gr.Video(
                        label="Output Video",
                        elem_id="video_output",
                        width=video_width,
                        height=video_height,
                    )
            video_submit = gr.Button("Detect Durian and Rambutan fruits in Video")
            video_submit.click(
                fn=process_video,
                inputs=[video_input],
                outputs=[video_output],
            )

# Start the web server for the interface defined above.
app.launch()