# Last change by kbarnard in commit 4b67eff: "update description and default imgsz".
import gradio as gr
import torch
from ultralytics import YOLO
import os
# Upstream project and the release asset holding the YOLOv12x weights.
REPO_URL = "https://github.com/WildHackers/community-fish-detector"
MODEL_URL = f"{REPO_URL}/releases/download/cfd-1.00-yolov12x/cfd-yolov12x-1.00.pt"

# Weights are cached in the current working directory under the release
# file name; the download only happens when that file is not there yet.
MODEL_PATH = os.path.basename(MODEL_URL)
if not os.path.exists(MODEL_PATH):
    torch.hub.download_url_to_file(MODEL_URL, MODEL_PATH)

# Built once at import time; run_detection reuses this single instance.
model = YOLO(MODEL_PATH)
def run_detection(input_image, conf_threshold: float = 0.60, iou_threshold: float = 0.45, imgsz: int = 1024):
    """
    Run YOLOv12x fish detection on one image and return it annotated.

    Args:
        input_image: Image from the Gradio ``gr.Image(type="numpy")`` input,
            expected to be an RGB (or RGBA) ``numpy.ndarray`` laid out as
            (H, W, C), or ``None`` when nothing has been uploaded.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold passed to non-maximum suppression.
        imgsz: Inference image size in pixels. It is cast to ``int`` because
            UI sliders may deliver it as a float.

    Returns:
        The annotated image as an RGB ``numpy.ndarray``, or ``None`` when
        ``input_image`` is ``None``.
    """
    if input_image is None:
        return None

    # Ultralytics interprets NumPy sources as BGR (OpenCV convention), but
    # Gradio supplies RGB. Reorder the channels so the model sees the colours
    # it expects; the `2::-1` slice also drops an alpha channel if present.
    # .copy() yields a contiguous array, which OpenCV-based code requires.
    if input_image.ndim == 3:
        input_image = input_image[..., 2::-1].copy()

    results = model.predict(
        source=input_image,
        conf=conf_threshold,
        iou=iou_threshold,
        imgsz=int(imgsz),
        save=False,
        verbose=False,
    )

    # Results.plot() returns a BGR array; flip it back to RGB for display.
    return results[0].plot()[..., ::-1].copy()
# --- Gradio interface ---------------------------------------------------
# Input components are listed in the same order as run_detection's
# parameters, because gr.Interface passes their values positionally.
_input_components = [
    gr.Image(type="numpy", label="Input Image"),
    gr.Slider(0, 1, value=0.60, step=0.01, label="Confidence Threshold"),
    gr.Slider(0, 1, value=0.45, step=0.01, label="IoU Threshold"),
    gr.Slider(320, 1280, value=1024, step=32, label="Image Size"),
]
_output_component = gr.Image(type="numpy", label="Detected Output")
_description = f"Upload an image to detect fish using the [Community Fish Detector]({REPO_URL})."

demo = gr.Interface(
    fn=run_detection,
    inputs=_input_components,
    outputs=_output_component,
    title="Community Fish Detector (YOLOv12x)",
    description=_description,
    flagging_mode="never",
)

if __name__ == "__main__":
    demo.launch()