# Hugging Face Space (page status at capture time: Sleeping)
from ultralytics import YOLO | |
from PIL import Image | |
import gradio as gr | |
from huggingface_hub import snapshot_download | |
import os | |
import tempfile | |
import cv2 | |
import zipfile | |
import shutil | |
# === Load model ===
def load_model(repo_id):
    """Download a Hugging Face repo snapshot and load its OpenVINO YOLO model.

    Args:
        repo_id: Hub repository identifier handed to ``snapshot_download``.

    Returns:
        A ``YOLO`` instance created with ``task='detect'`` from the
        ``best_int8_openvino_model`` directory inside the snapshot.
    """
    snapshot_dir = snapshot_download(repo_id)
    model_dir = os.path.join(snapshot_dir, "best_int8_openvino_model")
    detector = YOLO(model_dir, task='detect')
    return detector
# Hugging Face Hub repository the detection model is downloaded from.
REPO_ID = "sensura/belisha-beacon-zebra-crossing-yoloV8"
# Loaded once at import time; the prediction functions read this global.
detection_model = load_model(repo_id=REPO_ID)
# === Image prediction ===
def predict_image(image, conf_threshold, iou_threshold):
    """Run detection on a single image and return the annotated picture.

    Args:
        image: Image to analyse (the Gradio input component is configured
            with ``type="pil"``), or ``None`` when nothing was uploaded.
        conf_threshold: Confidence threshold forwarded to the model as ``conf``.
        iou_threshold: IoU threshold forwarded to the model as ``iou``.

    Returns:
        A PIL image with the detections drawn on it, or ``None`` when no
        input image was supplied.
    """
    # A submit without an upload arrives as None. Forwarding that to
    # predict() would not analyse user data at all (Ultralytics substitutes
    # its default sample source), so return an empty output instead.
    if image is None:
        return None

    results = detection_model.predict(image, conf=conf_threshold, iou=iou_threshold)
    annotated_bgr = results[0].plot()
    # plot() gives BGR channel order; reverse the last axis to get RGB for PIL.
    return Image.fromarray(annotated_bgr[..., ::-1])
# === Video prediction ===
def predict_video(video, conf_threshold, iou_threshold):
    """Annotate every frame of a video and return the path of the result.

    Args:
        video: Video source handed to ``cv2.VideoCapture`` (presumably the
            file path supplied by the Gradio video component), or
            ``None``/empty when nothing was uploaded.
        conf_threshold: Confidence threshold forwarded to the model as ``conf``.
        iou_threshold: IoU threshold forwarded to the model as ``iou``.

    Returns:
        Path of a temporary ``.mp4`` file containing the annotated frames,
        or ``None`` when no video was supplied.

    Raises:
        gr.Error: If OpenCV cannot open the supplied video.
    """
    if not video:
        return None

    cap = cv2.VideoCapture(video)
    out_writer = None
    try:
        if not cap.isOpened():
            # Otherwise a 0x0 writer would be created and the path of an
            # empty, unplayable file would be returned.
            raise gr.Error("Could not open the uploaded video.")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN) FPS, which the writer cannot
        # use; fall back to a conventional 30 FPS in that case.
        if not fps > 0:
            fps = 30.0

        # mkstemp + close reserves a path without keeping a handle open.
        fd, out_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        out_writer = cv2.VideoWriter(
            out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)
        )

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            result = detection_model.predict(frame, conf=conf_threshold, iou=iou_threshold)
            out_writer.write(result[0].plot())
    finally:
        # Release the capture and writer even if inference fails mid-video.
        cap.release()
        if out_writer is not None:
            out_writer.release()
    return out_path
# === Multiple images prediction ===
def _upload_path(file):
    """Return the filesystem path of an upload given as a path or a file object."""
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return getattr(file, "name", str(file))


def _unique_name(name, taken):
    """Return ``name``, or a suffixed variant of it that is not yet in ``taken``."""
    stem, ext = os.path.splitext(name)
    candidate = name
    counter = 1
    while candidate in taken:
        candidate = f"{stem}_{counter}{ext}"
        counter += 1
    return candidate


def predict_multiple(files, conf_threshold, iou_threshold):
    """Run detection on several images and bundle the results in a zip.

    Each upload may be a path string or an object exposing ``.name``.
    Files that cannot be processed are reported with ``print`` and skipped,
    so one bad upload does not abort the batch.

    Args:
        files: Uploaded image files, or ``None``/empty when nothing was
            uploaded.
        conf_threshold: Confidence threshold forwarded to the model as ``conf``.
        iou_threshold: IoU threshold forwarded to the model as ``iou``.

    Returns:
        A tuple ``(annotated_images, zip_path)``: the annotated PIL images
        and the path of a zip archive containing them. ``(None, None)``
        when no files were supplied.
    """
    if not files:
        return None, None

    output_dir = tempfile.mkdtemp()
    annotated_images = []
    used_names = set()
    try:
        for file in files:
            src_path = _upload_path(file)
            try:
                # The context manager closes the source file right after decoding.
                with Image.open(src_path) as src:
                    img = src.convert("RGB")
                result = detection_model.predict(img, conf=conf_threshold, iou=iou_threshold)
                img_bgr = result[0].plot()
                out_img = Image.fromarray(img_bgr[..., ::-1])  # Convert BGR to RGB for PIL
                # Uploads with the same basename must not overwrite each other.
                out_name = _unique_name(os.path.basename(src_path), used_names)
                used_names.add(out_name)
                out_img.save(os.path.join(output_dir, out_name))
                annotated_images.append(out_img)
            except Exception as e:
                print(f"Failed to process {src_path}: {e}")
        # The archive is written next to output_dir, not inside it.
        zip_path = shutil.make_archive(output_dir, 'zip', output_dir)
    finally:
        # The gallery holds in-memory images and the zip lives outside the
        # directory, so the scratch directory can be removed.
        shutil.rmtree(output_dir, ignore_errors=True)
    return annotated_images, zip_path
# === Gradio Interfaces ===
def _threshold_sliders():
    """Build a fresh confidence / IoU slider pair for one interface."""
    return [
        gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold"),
        gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold"),
    ]


# Single image in, annotated image out.
image_tab = gr.Interface(
    fn=predict_image,
    inputs=[gr.Image(type="pil", label="Upload Image"), *_threshold_sliders()],
    outputs=gr.Image(type="pil", label="Detected Image"),  # PIL to avoid color bugs
    title="Single Image Detection",
)

# Single video in, annotated video out.
video_tab = gr.Interface(
    fn=predict_video,
    inputs=[gr.Video(label="Upload Video"), *_threshold_sliders()],
    outputs=gr.Video(label="Detected Video"),
    title="Single Video Detection",
)

# Several images in, gallery plus downloadable zip out.
gallery_tab = gr.Interface(
    fn=predict_multiple,
    inputs=[
        gr.File(file_types=["image"], file_count="multiple", label="Upload Multiple Images"),
        *_threshold_sliders(),
    ],
    outputs=[
        gr.Gallery(label="Detected Gallery", columns=3, height="auto"),
        gr.File(label="Download Annotated ZIP"),
    ],
    title="Batch Image Detection",
)

# === Tabbed UI Launch ===
demo = gr.TabbedInterface(
    [image_tab, video_tab, gallery_tab],
    tab_names=["Image", "Video", "Gallery"],
)
demo.launch(share=True)