import gradio as gr
import cv2
import requests
import os
import random
from ultralytics import YOLO
import numpy as np
from collections import defaultdict
import sqlite3
import time
# Import the supervision library
import supervision as sv

# --- Initialize SQLite DB for logging ---
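# check_same_thread=False lets Gradio's worker threads reuse this single connection.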
conn = sqlite3.connect("detection_log.db", check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
    CREATE TABLE IF NOT EXISTS detections (
        timestamp REAL,
        frame_number INTEGER,
        bin_name TEXT,
        class_name TEXT,
        count INTEGER
    )
''')
conn.commit()

# --- File Downloading ---
# File URLs for sample images and video
file_urls = [
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/mix2.jpg?download=true',
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/mix11.jpg?download=true',
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/sample_waste.mp4?download=true',
]


def download_file(url, save_name):
    """Downloads a file from a URL, overwriting if it exists."""
    print(f"Downloading from: {url}")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()  # Check for HTTP errors
        with open(save_name, 'wb') as file:
            for chunk in response.iter_content(1024):
                file.write(chunk)
        print(f"Downloaded and overwrote: {save_name}")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")


# Download sample images and video for the examples
for i, url in enumerate(file_urls):
    if 'mp4' in url:
        download_file(url, "video.mp4")
    else:
        download_file(url, f"image_{i}.jpg")

# --- Model and Class Configuration ---
# Load your custom YOLO model
# IMPORTANT: Replace 'best.pt' with the path to your model trained on the 12 classes.
model = YOLO('best.pt')

# Get class names and generate colors dynamically from the loaded model
# This is the best practice as it ensures names and colors match the model's output.
class_names = model.model.names
class_colors = {
    name: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
    for name in class_names.values()
}

# Define paths for Gradio examples
image_example_paths = [['image_0.jpg'], ['image_1.jpg']]
video_example_path = [['video.mp4']]

# --- Image Processing Function ---
def show_preds_image(image_path):
    """Processes a single image and overlays YOLO predictions."""
    image = cv2.imread(image_path)
    outputs = model.predict(source=image_path, verbose=False)
    # Convert to a supervision Detections object for easier handling
    detections = sv.Detections.from_ultralytics(outputs[0])
    # Annotate the image with bounding boxes and labels
    for box, conf, cls in zip(detections.xyxy, detections.confidence, detections.class_id):
        x1, y1, x2, y2 = map(int, box)
        class_name = class_names[cls]
        color = class_colors[class_name]
        # Draw bounding box
        cv2.rectangle(image, (x1, y1), (x2, y2), color=color, thickness=2, lineType=cv2.LINE_AA)
        # Create and display label
        label = f"{class_name}: {conf:.2f}"
        cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# --- Video Processing Function (with Supervision) ---
def process_video_with_two_side_bins(video_path):
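    """Tracks and counts waste items as they enter two on-screen bin regions.

    Yields annotated RGB frames one at a time so Gradio can stream the output.
    """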
    generator = sv.get_video_frames_generator(video_path)
    try:
        first_frame = next(generator)
    except StopIteration:
        blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        yield cv2.cvtColor(blank_frame, cv2.COLOR_BGR2RGB)
        return
    frame_height, frame_width, _ = first_frame.shape
    # Re-create the generator so the first frame (read above only for its dimensions) is processed too
    generator = sv.get_video_frames_generator(video_path)
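    # Two bin regions, defined as fractions of the frame size so they scale
    # with any input resolution; coordinates are (x1, y1, x2, y2) in pixels.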
    bins = [
        {
            "name": "Recycle Bin",
            "coords": (
                int(frame_width * 0.05),
                int(frame_height * 0.5),
                int(frame_width * 0.25),
                int(frame_height * 0.95),
            ),
            "color": (200, 16, 46),  # Blue (BGR)
        },
        {
            "name": "Trash Bin",
            "coords": (
                int(frame_width * 0.75),
                int(frame_height * 0.5),
                int(frame_width * 0.95),
                int(frame_height * 0.95),
            ),
            "color": (50, 50, 50),  # Dark gray (BGR)
        },
    ]
    box_annotator = sv.BoxAnnotator(thickness=2)
    label_annotator = sv.LabelAnnotator(
        text_scale=1.2,
        text_thickness=3,
        text_position=sv.Position.TOP_LEFT,
    )
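    # ByteTrack assigns a persistent tracker_id to each object across frames,
    # so an item is counted at most once as it enters a bin.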
    tracker = sv.ByteTrack()
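    # Per-bin state: items_in_bins remembers which tracker IDs were already counted;
    # class_counts_per_bin holds the cumulative per-class totals shown in the overlay.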
    items_in_bins = {bin_["name"]: set() for bin_ in bins}
    class_counts_per_bin = {bin_["name"]: defaultdict(int) for bin_ in bins}
    frame_number = 0
    BATCH_SIZE = 10
    LOGGED_OBJECT_TTL_SECONDS = 300  # 5 minutes
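    # insert_buffer batches rows for SQLite writes; logged_objects maps
    # (track_id, bin, class) -> last log time so an item is logged at most once per TTL window.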
    insert_buffer = []
    logged_objects = {}

    for frame in generator:
        frame_number += 1
        current_time = time.time()
        # Prune old logged objects every BATCH_SIZE frames
        if frame_number % BATCH_SIZE == 0:
            keys_to_remove = [key for key, ts in logged_objects.items()
                              if current_time - ts > LOGGED_OBJECT_TTL_SECONDS]
            for key in keys_to_remove:
                del logged_objects[key]

        results = model(frame, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(results)
        tracked_detections = tracker.update_with_detections(detections)
        annotated_frame = frame.copy()

        # Draw bins and labels
        for bin_ in bins:
            x1, y1, x2, y2 = bin_["coords"]
            color = bin_["color"]
            cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color=color, thickness=3)
            cv2.putText(
                annotated_frame,
                bin_["name"],
                (x1 + 5, y1 - 15),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.5,
                color,
                3,
                cv2.LINE_AA,
            )

        if tracked_detections.tracker_id is None:
            yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
            continue

        for box, track_id, class_id in zip(
            tracked_detections.xyxy,
            tracked_detections.tracker_id,
            tracked_detections.class_id,
        ):
            x1, y1, x2, y2 = map(int, box)
            cx = (x1 + x2) // 2
            cy = (y1 + y2) // 2
            class_name = class_names[class_id]
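            # An item belongs to a bin when its box center falls inside the bin rectangle.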
            for bin_ in bins:
                bx1, by1, bx2, by2 = bin_["coords"]
                bin_name = bin_["name"]
                if (bx1 <= cx <= bx2) and (by1 <= cy <= by2):
                    key = (track_id, bin_name, class_name)
                    if track_id not in items_in_bins[bin_name]:
                        items_in_bins[bin_name].add(track_id)
                        class_counts_per_bin[bin_name][class_name] += 1
                    if key not in logged_objects:
                        timestamp = time.time()
                        insert_buffer.append((timestamp, frame_number, bin_name, class_name, 1))
                        logged_objects[key] = current_time

        # Batch insert every BATCH_SIZE frames
        if frame_number % BATCH_SIZE == 0 and insert_buffer:
            cursor.executemany('''
                INSERT INTO detections (timestamp, frame_number, bin_name, class_name, count)
                VALUES (?, ?, ?, ?, ?)
            ''', insert_buffer)
            conn.commit()
            insert_buffer.clear()

        labels = [
            f"#{tid} {class_names[cid]}"
            for cid, tid in zip(tracked_detections.class_id, tracked_detections.tracker_id)
        ]
        annotated_frame = box_annotator.annotate(
            scene=annotated_frame, detections=tracked_detections
        )
        annotated_frame = label_annotator.annotate(
            scene=annotated_frame, detections=tracked_detections, labels=labels
        )

        # Display counts per bin
        y_pos = 50
        for bin_name, class_count_dict in class_counts_per_bin.items():
            text = (
                f"{bin_name}: "
                + ", ".join(f"{cls}={count}" for cls, count in class_count_dict.items())
            )
            cv2.putText(
                annotated_frame,
                text,
                (30, y_pos),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.1,
                (255, 255, 255),
                3,
                cv2.LINE_AA,
            )
            y_pos += 40

        yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)

    # Insert any remaining buffered data at end
    if insert_buffer:
        cursor.executemany('''
            INSERT INTO detections (timestamp, frame_number, bin_name, class_name, count)
            VALUES (?, ?, ?, ?, ?)
        ''', insert_buffer)
        conn.commit()
        insert_buffer.clear()

# --- Gradio Interface Setup ---
# Gradio Interface for Image Processing
interface_image = gr.Interface(
    fn=show_preds_image,
    inputs=gr.Image(type="filepath", label="Input Image"),
    outputs=gr.Image(type="numpy", label="Output Image"),
    title="Waste Detection (Image)",
    description="Upload an image to see waste detection results.",
    examples=image_example_paths,
    cache_examples=False,
)

# Gradio Interface for Video Processing
interface_video = gr.Interface(
    fn=process_video_with_two_side_bins,
    inputs=gr.Video(label="Input Video"),
    outputs=gr.Image(type="numpy", label="Output Video Stream"),
    title="Waste Tracking and Counting (Video)",
    description="Upload a video to see real-time object tracking and counting.",
    examples=video_example_path,
    cache_examples=False,
)

# Launch the Gradio app with separate tabs for each interface
gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=['Image Inference', 'Video Inference']
).queue().launch(debug=True)