File size: 3,826 Bytes
e9a13e0
 
 
 
 
d11116c
 
5c5412c
 
e9a13e0
8e2e3e2
e9a13e0
 
288c8ce
d11116c
e9a13e0
8e2e3e2
 
 
368ea00
 
 
 
28d61cf
 
368ea00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e2e3e2
 
 
 
 
649ed43
8e2e3e2
 
 
 
 
 
5c5412c
 
52b1b0a
8e2e3e2
 
 
 
 
d11116c
8e2e3e2
 
 
 
 
368ea00
 
8e2e3e2
ea8ef48
8e2e3e2
 
 
28d61cf
368ea00
 
 
 
 
 
 
 
 
8e2e3e2
368ea00
 
8e2e3e2
e9a13e0
368ea00
8e2e3e2
288c8ce
8e2e3e2
 
 
288c8ce
649ed43
8e2e3e2
 
649ed43
368ea00
5c5412c
 
8e2e3e2
 
368ea00
 
8e2e3e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from ultralytics import YOLO
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
import os
import tempfile
import cv2
import zipfile
import shutil

# === Load model ===
def load_model(repo_id):
    """Download *repo_id* from the Hugging Face Hub and load its detector.

    The model is read from the ``best_int8_openvino_model`` directory inside
    the downloaded snapshot and opened as a YOLO detection model.

    Args:
        repo_id: Hub repository identifier, e.g. ``"owner/name"``.

    Returns:
        A ``YOLO`` instance configured for the ``detect`` task.
    """
    return YOLO(
        os.path.join(snapshot_download(repo_id), "best_int8_openvino_model"),
        task='detect',
    )

# Hugging Face Hub repository that hosts the exported detector.
REPO_ID = "sensura/belisha-beacon-zebra-crossing-yoloV8"
# Loaded once at import time; every predict_* function below reuses it.
detection_model = load_model(repo_id=REPO_ID)

# === Image prediction ===
def predict_image(image, conf_threshold, iou_threshold):
    """Run the detector on one image and return an annotated copy.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.
        conf_threshold: Minimum confidence, forwarded as ``predict(conf=...)``.
        iou_threshold: IoU threshold, forwarded as ``predict(iou=...)``.

    Returns:
        The annotated image as a PIL ``Image``, or ``None`` if no image was
        supplied (Gradio shows an empty output).
    """
    if image is None:
        # Nothing uploaded: clear the output rather than calling predict()
        # without a real source.
        return None
    result = detection_model.predict(image, conf=conf_threshold, iou=iou_threshold)
    # plot() returns the drawn frame as a BGR array; flip the channel axis to
    # RGB so PIL shows the correct colours.
    annotated_bgr = result[0].plot()
    return Image.fromarray(annotated_bgr[..., ::-1])

# === Video prediction ===
def predict_video(video, conf_threshold, iou_threshold):
    """Annotate every frame of a video and return the path of the result.

    Args:
        video: Path of the uploaded video (from ``gr.Video``), or ``None``
            when nothing was uploaded.
        conf_threshold: Minimum confidence, forwarded as ``predict(conf=...)``.
        iou_threshold: IoU threshold, forwarded as ``predict(iou=...)``.

    Returns:
        Path to a temporary ``.mp4`` file with the annotated frames, or
        ``None`` if no video was supplied.

    Raises:
        gr.Error: If the input cannot be opened or the output writer cannot
            be created (shown to the user by Gradio).
    """
    if video is None:
        return None

    fallback_fps = 30.0  # used when the container does not report a frame rate

    cap = cv2.VideoCapture(video)
    out_writer = None
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open the uploaded video.")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # `not fps > 0` also catches NaN, which a plain `fps <= 0` would miss.
        if not fps > 0:
            fps = fallback_fps

        # Reserve a unique output filename and close our handle immediately so
        # that VideoWriter is the only writer of the file.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            out_path = tmp.name

        out_writer = cv2.VideoWriter(
            out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)
        )
        if not out_writer.isOpened():
            raise gr.Error("Could not create the annotated output video.")

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # verbose=False keeps predict() from logging one line per frame.
            result = detection_model.predict(
                frame, conf=conf_threshold, iou=iou_threshold, verbose=False
            )
            # plot() returns BGR, which is what VideoWriter expects.
            out_writer.write(result[0].plot())
    finally:
        # Release native handles even if predict() or writing raises.
        cap.release()
        if out_writer is not None:
            out_writer.release()

    return out_path

# === Multiple images prediction ===
def predict_multiple(files, conf_threshold, iou_threshold):
    if not files:
        return None, None

    output_dir = tempfile.mkdtemp()
    annotated_images = []

    for file in files:
        try:
            img = Image.open(file).convert("RGB")
            result = detection_model.predict(img, conf=conf_threshold, iou=iou_threshold)
            img_bgr = result[0].plot()
            out_img = Image.fromarray(img_bgr[..., ::-1])  # Convert BGR to RGB for PIL
            out_path = os.path.join(output_dir, os.path.basename(file.name))
            out_img.save(out_path)
            annotated_images.append(out_img)
        except Exception as e:
            print(f"Failed to process {file.name}: {e}")

    zip_path = shutil.make_archive(output_dir, 'zip', output_dir)
    return annotated_images, zip_path

# === Gradio Interfaces ===

def _threshold_sliders():
    """Build a fresh [confidence, IoU] slider pair for one interface.

    A new pair is created on every call, so no component instance is shared
    between tabs. Confidence defaults to 0.5 and IoU to 0.6.
    """
    confidence = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
    iou = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold")
    return [confidence, iou]


image_tab = gr.Interface(
    fn=predict_image,
    inputs=[gr.Image(type="pil", label="Upload Image"), *_threshold_sliders()],
    # PIL output matches predict_image's return type and avoids colour-order bugs.
    outputs=gr.Image(type="pil", label="Detected Image"),
    title="Single Image Detection",
)

video_tab = gr.Interface(
    fn=predict_video,
    inputs=[gr.Video(label="Upload Video"), *_threshold_sliders()],
    outputs=gr.Video(label="Detected Video"),
    title="Single Video Detection",
)

gallery_tab = gr.Interface(
    fn=predict_multiple,
    inputs=[
        gr.File(file_types=["image"], file_count="multiple", label="Upload Multiple Images"),
        *_threshold_sliders(),
    ],
    outputs=[
        gr.Gallery(label="Detected Gallery", columns=3, height="auto"),
        gr.File(label="Download Annotated ZIP"),
    ],
    title="Batch Image Detection",
)

demo = gr.TabbedInterface(
    [image_tab, video_tab, gallery_tab],
    tab_names=["Image", "Video", "Gallery"],
)
demo.launch(share=True)