232975Q / app.py
sensura's picture
Update app.py
ea8ef48 verified
from ultralytics import YOLO
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
import os
import tempfile
import cv2
import zipfile
import shutil
# === Load model ===
def load_model(repo_id):
    """Fetch the model snapshot from the Hub and load the OpenVINO-exported YOLO model.

    Args:
        repo_id: Hugging Face Hub repository id containing the exported weights.

    Returns:
        A YOLO detection model backed by the int8 OpenVINO export.
    """
    snapshot_dir = snapshot_download(repo_id)
    model_path = os.path.join(snapshot_dir, "best_int8_openvino_model")
    return YOLO(model_path, task='detect')
# Hub repository holding the exported detection weights.
REPO_ID = "sensura/belisha-beacon-zebra-crossing-yoloV8"
# Loaded once at import time; downloads the snapshot on the first run.
detection_model = load_model(REPO_ID)
# === Image prediction ===
def predict_image(image, conf_threshold, iou_threshold):
    """Run detection on a single image and return the annotated result.

    Args:
        image: Input image (PIL, from the Gradio Image component).
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold used by non-max suppression.

    Returns:
        Annotated image as a PIL.Image (RGB).
    """
    results = detection_model.predict(image, conf=conf_threshold, iou=iou_threshold)
    annotated_bgr = results[0].plot()
    # plot() returns a BGR array; reverse the channel axis so PIL sees RGB.
    annotated_rgb = annotated_bgr[..., ::-1]
    return Image.fromarray(annotated_rgb)
# === Video prediction ===
def predict_video(video, conf_threshold, iou_threshold):
    """Run detection on every frame of a video and write an annotated copy.

    Args:
        video: Path to the input video file (from the Gradio Video component).
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold used by non-max suppression.

    Returns:
        Path to the annotated .mp4 file.

    Raises:
        ValueError: If the input video cannot be opened.
    """
    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        # Fail loudly instead of silently emitting an empty output file.
        raise ValueError(f"Could not open video: {video}")
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Some containers report 0 fps, which would produce a broken output file.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    out_writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            result = detection_model.predict(frame, conf=conf_threshold, iou=iou_threshold)
            # plot() returns BGR, which is exactly what VideoWriter expects.
            out_writer.write(result[0].plot())
    finally:
        # Release both handles even if prediction raises mid-video.
        cap.release()
        out_writer.release()
    return out_path
# === Multiple images prediction ===
def predict_multiple(files, conf_threshold, iou_threshold):
    """Run detection on a batch of uploaded images.

    Args:
        files: List of uploads from gr.File — plain path strings or file-like
            objects with a ``.name`` attribute, depending on the Gradio version.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold used by non-max suppression.

    Returns:
        Tuple of (list of annotated PIL images, path to a ZIP of the results),
        or (None, None) when no files were provided.
    """
    if not files:
        return None, None
    output_dir = tempfile.mkdtemp()
    annotated_images = []
    for file in files:
        # Normalize once: newer Gradio passes str paths, older passes objects
        # with .name — the original mixed both and crashed on plain strings.
        path = file if isinstance(file, str) else file.name
        try:
            # Close the file handle promptly; convert("RGB") returns a copy.
            with Image.open(path) as src:
                img = src.convert("RGB")
            result = detection_model.predict(img, conf=conf_threshold, iou=iou_threshold)
            img_bgr = result[0].plot()
            out_img = Image.fromarray(img_bgr[..., ::-1])  # BGR -> RGB for PIL
            out_img.save(os.path.join(output_dir, os.path.basename(path)))
            annotated_images.append(out_img)
        except Exception as e:
            # Best-effort batch: log and keep going with the remaining files.
            print(f"Failed to process {path}: {e}")
    zip_path = shutil.make_archive(output_dir, 'zip', output_dir)
    return annotated_images, zip_path
# === Gradio Interfaces ===
# Tab 1: single-image detection.
_img_input = gr.Image(type="pil", label="Upload Image")
_img_conf = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
_img_iou = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold")
image_tab = gr.Interface(
    fn=predict_image,
    inputs=[_img_input, _img_conf, _img_iou],
    # PIL output avoids channel-order surprises in the browser preview.
    outputs=gr.Image(type="pil", label="Detected Image"),
    title="Single Image Detection",
)
# Tab 2: single-video detection.
_vid_input = gr.Video(label="Upload Video")
_vid_conf = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
_vid_iou = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold")
video_tab = gr.Interface(
    fn=predict_video,
    inputs=[_vid_input, _vid_conf, _vid_iou],
    outputs=gr.Video(label="Detected Video"),
    title="Single Video Detection",
)
# Tab 3: batch detection over multiple uploaded images.
_batch_input = gr.File(file_types=["image"], file_count="multiple", label="Upload Multiple Images")
_batch_conf = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
_batch_iou = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold")
gallery_tab = gr.Interface(
    fn=predict_multiple,
    inputs=[_batch_input, _batch_conf, _batch_iou],
    outputs=[
        gr.Gallery(label="Detected Gallery", columns=3, height="auto"),
        gr.File(label="Download Annotated ZIP"),
    ],
    title="Batch Image Detection",
)
# === Tabbed UI Launch ===
# Assemble the three tabs into one app and start it with a public share link.
demo = gr.TabbedInterface(
    [image_tab, video_tab, gallery_tab],
    tab_names=["Image", "Video", "Gallery"],
)
demo.launch(share=True)