|
|
import gradio as gr |
|
|
import torch |
|
|
from ultralytics import YOLO |
|
|
import os |
|
|
|
|
|
REPO_URL = "https://github.com/WildHackers/community-fish-detector"

# GitHub release asset containing the YOLOv12x weights from the repository above.
MODEL_URL = f"{REPO_URL}/releases/download/cfd-1.00-yolov12x/cfd-yolov12x-1.00.pt"

# Local file name for the weights: the last path component of the URL,
# resolved relative to the current working directory.
MODEL_PATH = os.path.basename(MODEL_URL)


def _download_if_missing(url, path):
    """Download ``url`` to ``path`` unless something already exists at ``path``."""
    if os.path.exists(path):
        return
    torch.hub.download_url_to_file(url, path)


# Runs at import time: fetch the weights once, then load them into a shared model.
_download_if_missing(MODEL_URL, MODEL_PATH)
model = YOLO(MODEL_PATH)
|
|
|
|
|
def run_detection(input_image, conf_threshold: float = 0.60, iou_threshold: float = 0.45, imgsz: int = 1024):
    """Run YOLOv12x fish detection on an image and return an annotated copy.

    Args:
        input_image: Image from the Gradio ``gr.Image(type="numpy")`` input
            (an RGB ``numpy`` array), or ``None`` when no image was uploaded.
            Non-array inputs are handed to ``model.predict`` unchanged.
        conf_threshold: Minimum detection confidence, passed as ``conf``.
        iou_threshold: IoU threshold used by non-maximum suppression, passed as ``iou``.
        imgsz: Inference image size, passed through to ``model.predict``.

    Returns:
        The annotated image as an RGB ``numpy`` array, or ``None`` if
        ``input_image`` is ``None``.
    """
    import numpy as np

    if input_image is None:
        return None

    # Gradio supplies RGB arrays, but Ultralytics interprets raw numpy sources
    # as OpenCV-style BGR. Flip the channel axis so the model sees true colours.
    # ascontiguousarray avoids handing OpenCV a negative-stride view.
    is_three_channel_array = (
        isinstance(input_image, np.ndarray)
        and input_image.ndim == 3
        and input_image.shape[2] == 3
    )
    if is_three_channel_array:
        source = np.ascontiguousarray(input_image[..., ::-1])
    else:
        source = input_image

    results = model.predict(
        source=source,
        conf=conf_threshold,
        iou=iou_threshold,
        imgsz=imgsz,
        save=False,
        verbose=False,
    )

    # Results.plot() returns a BGR array; convert back to RGB for the Gradio output.
    annotated_bgr = results[0].plot()
    return np.ascontiguousarray(annotated_bgr[..., ::-1])
|
|
|
|
|
|
|
|
# UI controls, listed in the same order as run_detection's positional parameters.
_input_components = [
    gr.Image(type="numpy", label="Input Image"),
    gr.Slider(0, 1, value=0.60, step=0.01, label="Confidence Threshold"),
    gr.Slider(0, 1, value=0.45, step=0.01, label="IoU Threshold"),
    gr.Slider(320, 1280, value=1024, step=32, label="Image Size"),
]
_output_component = gr.Image(type="numpy", label="Detected Output")
_description = f"Upload an image to detect fish using the [Community Fish Detector]({REPO_URL})."

# Single-function Gradio app wrapping run_detection; flagging is disabled.
demo = gr.Interface(
    fn=run_detection,
    inputs=_input_components,
    outputs=_output_component,
    title="Community Fish Detector (YOLOv12x)",
    description=_description,
    flagging_mode="never",
)


if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()
|
|
|