# -*- coding: utf-8 -*-
"""Judol Gradio YOLO11.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1oiuTAi-cys1ydtUhSDJSRdeA02mAmZQH
"""
import os

import cv2
import gradio as gr
import imageio
from ultralytics import YOLO

# Ultralytics accepts a direct URL; the weights are fetched from the Hugging Face Hub.
model = YOLO('https://huggingface.co/JrEasy/Judol-Detection-YOLO11/resolve/main/best.pt')

# Minimum confidence for a detection to be drawn and logged.
confidence_threshold = 0.6

# Model class indices mapped to human-readable names.
class_names = {
    0: "BK8",
    1: "Gate of Olympus",
    2: "Princess",
    3: "Starlight Princess",
    4: "Zeus",
}

# Box/label colors per class, in OpenCV's BGR channel order.
class_colors = {
    0: (0, 255, 0),    # Green for BK8
    1: (255, 0, 0),    # Blue for Gate of Olympus
    2: (0, 0, 255),    # Red for Princess
    3: (255, 255, 0),  # Cyan for Starlight Princess
    4: (255, 0, 255),  # Magenta for Zeus
}


def format_time_ranges(timestamps, classes):
    """Collapse per-frame detection timestamps into per-class second ranges.

    Consecutive timestamps no more than one second apart are merged into a
    single "start-end" range.
    """
    if not timestamps:
        return ""

    # Bucket timestamps by class name.
    class_timestamps = {}
    for timestamp, class_id in zip(timestamps, classes):
        class_name = class_names.get(class_id, "Unknown")
        class_timestamps.setdefault(class_name, []).append(timestamp)

    formatted_ranges = []
    for class_name, ts in class_timestamps.items():
        ts = sorted(ts)
        ranges = []
        start = ts[0]
        for i in range(1, len(ts)):
            # A gap of more than one second closes the current range.
            if ts[i] - ts[i - 1] > 1:
                ranges.append(f"{int(start)}-{int(ts[i - 1])}")
                start = ts[i]
        ranges.append(f"{int(start)}-{int(ts[-1])}")
        formatted_ranges.append(f"{class_name} = {', '.join(ranges)}")
    return ", ".join(formatted_ranges)


def process_video(input_video):
    """Run detection on every frame of a video; return the annotated file path
    and a per-class summary of when detections occurred."""
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        print("Error: Could not open input video.")
        return None, ""

    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back if FPS metadata is missing
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Write the annotated video to the current working directory.
    output_video_path = os.path.join(os.getcwd(), "processed_video.mp4")
    writer = imageio.get_writer(output_video_path, fps=fps, codec="h264")

    frame_count = 0
    timestamps = []
    classes_detected = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        timestamp = frame_count / fps
        frame_count += 1

        # Resize to 640x640 and stack grayscale into 3 channels before inference.
        resized_frame = cv2.resize(frame, (640, 640))
        gray_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2GRAY)
        input_frame = cv2.merge([gray_frame, gray_frame, gray_frame])

        results = model.predict(input_frame, verbose=False)
        for result in results:
            for box in result.boxes:
                if box.conf[0] >= confidence_threshold:
                    x1, y1, x2, y2 = map(int, box.xyxy[0])
                    class_id = int(box.cls[0])
                    class_name = class_names.get(class_id, f"Class {class_id}")
                    color = class_colors.get(class_id, (0, 255, 0))
                    cv2.rectangle(resized_frame, (x1, y1), (x2, y2), color, 2)
                    text = f'{class_name}, Conf: {box.conf[0]:.2f}'
                    # Keep the label inside the frame when the box touches the top edge.
                    text_position = (x1, y1 - 10 if y1 > 20 else y1 + 20)
                    cv2.putText(resized_frame, text, text_position,
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                    timestamps.append(timestamp)
                    classes_detected.append(class_id)

        # Restore the original resolution and convert BGR -> RGB for imageio.
        output_frame = cv2.resize(resized_frame, (frame_width, frame_height))
        writer.append_data(cv2.cvtColor(output_frame, cv2.COLOR_BGR2RGB))

    cap.release()
    writer.close()
    formatted_time_ranges = format_time_ranges(timestamps, classes_detected)
    print(f"Processed video saved at: {output_video_path}")
    return output_video_path, formatted_time_ranges
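
# Standalone usage sketch, outside the Gradio UI (hypothetical filename):
#   path, summary = process_video("sample.mp4")
#   print(summary)  # e.g. "Zeus = 12-15, BK8 = 40-42"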


def process_image(input_image):
    """Run detection on a single image; return the annotated image and a
    'Class = Count' summary of detections."""
    if input_image is None:
        return None, ""

    # Gradio supplies RGB; OpenCV works in BGR.
    bgr_frame = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
    # Resize to 640x640 and stack grayscale into 3 channels before inference.
    resized_frame = cv2.resize(bgr_frame, (640, 640))
    gray_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2GRAY)
    input_frame = cv2.merge([gray_frame, gray_frame, gray_frame])

    results = model.predict(input_frame, verbose=False)
    classes_detected = []
    for result in results:
        for box in result.boxes:
            if box.conf[0] >= confidence_threshold:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                class_id = int(box.cls[0])
                class_name = class_names.get(class_id, f"Class {class_id}")
                color = class_colors.get(class_id, (0, 255, 0))  # default green
                cv2.rectangle(resized_frame, (x1, y1), (x2, y2), color, 2)
                text = f'{class_name}, Conf: {box.conf[0]:.2f}'
                # Keep the label inside the frame when the box touches the top edge.
                text_position = (x1, y1 - 10 if y1 > 20 else y1 + 20)
                cv2.putText(resized_frame, text, text_position,
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                classes_detected.append(class_id)

    # Count detections per class and format as 'Class = Count' pairs.
    class_count = {class_names.get(cls, f"Class {cls}"): classes_detected.count(cls)
                   for cls in set(classes_detected)}
    formatted_log = ", ".join(f"{name} = {count}" for name, count in class_count.items())

    # Convert back to RGB for display in Gradio.
    output_image = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGB)
    return output_image, formatted_log
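
# Standalone usage sketch (hypothetical file; assumes an RGB image array):
#   import imageio.v3 as iio
#   annotated, counts = process_image(iio.imread("screenshot.png"))
#   print(counts)  # e.g. "BK8 = 2, Zeus = 1"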


with gr.Blocks() as app:
    gr.Markdown("## Judol Detection using YOLO11")

    with gr.Tab("Video Detection"):
        with gr.Row():
            input_video = gr.Video(label="Upload a video")
            output_video = gr.Video(label="Processed Video")
        detections_log = gr.Textbox(label="Detections Log", lines=10)
        # Re-run detection whenever a new video is uploaded; clear outputs otherwise.
        input_video.change(
            fn=lambda video: process_video(video) if video else (None, ""),
            inputs=input_video,
            outputs=[output_video, detections_log],
        )

    with gr.Tab("Image Detection"):
        with gr.Row():
            input_image = gr.Image(label="Upload an image")
            output_image = gr.Image(label="Processed Image")
        image_detections_log = gr.Textbox(label="Detections Log", lines=10)
        input_image.change(
            fn=process_image,
            inputs=input_image,
            outputs=[output_image, image_detections_log],
        )
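
# launch() blocks and serves the UI with default settings; launch(share=True) would
# additionally create a public link when running locally or in Colab (assumption:
# the defaults suffice for a hosted deployment such as Hugging Face Spaces).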
app.launch()