import os
import gradio as gr
import torch
from PIL import Image
import numpy as np
from rfdetr import RFDETRMedium
import supervision as sv
import cv2
# --- Load RF-DETR model ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model = None # Initialize model as None
try:
weights_path = "weights.pth" # Adjust this if your weights file is elsewhere
if not os.path.exists(weights_path):
print(f"WARNING: Model weights file not found at: {weights_path}. Loading default RF-DETR weights.")
model = RFDETRMedium() # Load default COCO pre-trained weights if custom weights not found
else:
model = RFDETRMedium(pretrain_weights=weights_path)
print("RF-DETR model loaded successfully.")
except Exception as e:
print(f"An error occurred during RF-DETR model loading: {e}")
print("Please ensure 'weights.pth' is valid or remove pretrain_weights for default loading.")
exit() # Exit if model cannot be loaded
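# Reusable supervision annotator that draws a bounding box for each detection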
box_annotator = sv.BoxAnnotator()
# Inference
def detect_objects_image(image_np, threshold=0.5):
if image_np is None:
return np.zeros((480, 640, 3), dtype=np.uint8), 0
# Convert NumPy array to PIL Image for model.predict()
pil_img = Image.fromarray(image_np)
detections = model.predict(pil_img, threshold=threshold)
annotated_image_np = box_annotator.annotate(image_np.copy(), detections)
    return annotated_image_np, len(detections)
def detect_objects_video(video_path, threshold=0.5):
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print(f"Error: Could not open video file {video_path}")
return None
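    # Write annotated frames to a new MPEG-4 ('mp4v') file in the working directory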
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
output_path = "output_annotated.mp4"
# Get video properties for VideoWriter
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
if not out.isOpened():
print(f"Error: Could not create video writer for {output_path}")
cap.release()
return None
total_detections_count = 0 # Initialize counter for video
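    # Read, run inference on, annotate, and write out each frame in turn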
while True:
ret, frame = cap.read()
if not ret:
break
img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_img = Image.fromarray(img_rgb)
detections = model.predict(pil_img, threshold=threshold)
        total_detections_count += len(detections)  # Accumulate detections across frames
annotated_np = box_annotator.annotate(img_rgb.copy(), detections)
annotated_bgr = cv2.cvtColor(annotated_np, cv2.COLOR_RGB2BGR)
out.write(annotated_bgr)
cap.release()
out.release()
print(f"Processed video saved to: {output_path}")
return output_path
# --- Gradio interface ---
# Optional CSS; the .my-group class is defined here but not applied in the current layout
css = ".my-group {max-width: 600px !important; max-height: 600px !important;}"
examples = [
[os.path.join("examples", "image1.jpg"), 0.5],
[os.path.join("examples", "image2.jpg"), 0.5],
[os.path.join("examples", "image3.jpg"), 0.5],
]
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
        <h1 style="text-align: center;">Dense Retail Object Detection</h1>
        """
    )
with gr.Tab("Image"):
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(
type="numpy",
label="Upload Image",
height=400, # Fixed height
width=400 # Fixed width
)
threshold_slider = gr.Slider(
0.0, 1.0, value=0.5, step=0.05,
label="Detection Threshold"
)
with gr.Column(scale=1):
output_image = gr.Image(
type="numpy",
label="Annotated Image",
height=400, # Match input
width=400 # Match input
)
output_count = gr.Number(label="Total Detections")
btn = gr.Button("Detect Objects")
btn.click(
fn=detect_objects_image,
inputs=[image_input, threshold_slider],
outputs=[output_image, output_count]
)
gr.Examples(
examples=examples,
inputs=[image_input, threshold_slider],
outputs=[output_image, output_count],
label="Example Images"
)
with gr.Tab("Video"):
with gr.Row(): # Use gr.Row for left-right layout
with gr.Column(scale=1):
video_input = gr.Video(label="Upload Video (.mp4, .mov)")
threshold_slider_vid = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Detection Threshold")
with gr.Column(scale=1):
output_video = gr.Video(label="Annotated Video")
btn_video = gr.Button("Process Video")
btn_video.click(fn=detect_objects_video,
inputs=[video_input, threshold_slider_vid],
outputs=[output_video])
gr.Examples(
examples=[
["examples/video1.mp4", 0.5]
],
inputs=[video_input, threshold_slider_vid],
outputs=[output_video],
label="Example Videos"
)
with gr.Tab("Webcam"):
        # Usage instructions for the webcam tab
        gr.Markdown("""
        ### Webcam Input
        Capture an image or record a video with your webcam, then click the corresponding button to run object detection.
        """)
with gr.Row():
with gr.Column(scale=1):
# Webcam input for single image capture
                webcam_image_input = gr.Image(label="Capture Image from Webcam", sources=["webcam"])
btn_capture_image = gr.Button("Process Image")
gr.Markdown("---") # Separator for clarity
# Webcam input for recording video
                webcam_video_input = gr.Video(label="Record Video from Webcam", sources=["webcam"])
btn_process_recorded_video = gr.Button("Process Recorded Video")
threshold_slider_webcam = gr.Slider( # Single threshold slider for both image and video processing
label="Detection Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.5,
)
with gr.Column(scale=1): # Right column for outputs
# Output for image detection
output_image_webcam = gr.Image(label="Annotated Image Output")
output_count_image_webcam = gr.Number(label="Total Detections (Image)")
gr.Markdown("---") # Separator for clarity
# Output for video detection
output_video_webcam_rec = gr.Video(label="Processed Recorded Video Output")
#output_count_video_webcam_rec = gr.Number(label="Total Detections (Video)")
# Event for single image capture and detection
btn_capture_image.click(
fn=detect_objects_image, # This returns image, count
inputs=[webcam_image_input, threshold_slider_webcam],
outputs=[output_image_webcam, output_count_image_webcam]
)
# Event for recorded video processing
        btn_process_recorded_video.click(
            fn=detect_objects_video,  # Returns the path of the annotated video file
            inputs=[webcam_video_input, threshold_slider_webcam],
            outputs=[output_video_webcam_rec]  # Only the annotated video is displayed
        )
# Launch the demo
demo.launch(share=True)