SpyC0der77 commited on
Commit
97f0b00
Β·
verified Β·
1 Parent(s): 795111b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from ultralytics import YOLO
3
+ from huggingface_hub import hf_hub_download
4
+ from PIL import Image
5
+
6
+ MODEL_REPOS = {
7
+ "Footprint YOLO": "risashinoda/footprint_yolo",
8
+ "Feces YOLO": "risashinoda/feces_yolo",
9
+ "Egg YOLO": "risashinoda/egg_yolo",
10
+ "Bone YOLO": "risashinoda/bone_yolo",
11
+ "Feather YOLO": "risashinoda/feather_yolo"
12
+ }
13
+
14
+ _loaded = {}
15
+
16
+ def _load(model_key, weights_name="last.pt"):
17
+ if model_key not in _loaded:
18
+ repo_id = MODEL_REPOS[model_key]
19
+ w = hf_hub_download(repo_id=repo_id, filename=weights_name)
20
+ _loaded[model_key] = YOLO(w)
21
+ return _loaded[model_key]
22
+
23
+ def infer(image, model_key, conf_thres=None, iou_nms=None, draw_labels=None):
24
+ conf_thres = 0.25 if conf_thres is None else float(conf_thres)
25
+ iou_nms = 0.70 if iou_nms is None else float(iou_nms)
26
+ draw_labels = True if draw_labels is None else bool(draw_labels)
27
+
28
+ model = _load(model_key)
29
+ results = model.predict(image, conf=conf_thres, iou=iou_nms)
30
+ r = results[0]
31
+
32
+ plotted = r.plot()
33
+ img_out = Image.fromarray(plotted[..., ::-1])
34
+
35
+ if not draw_labels:
36
+ import numpy as np, cv2
37
+ img = results[0].orig_img.copy()
38
+ if img.ndim == 2:
39
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
40
+ for box in r.boxes.xyxy.tolist():
41
+ x1, y1, x2, y2 = map(int, box)
42
+ cv2.rectangle(img, (x1, y1), (x2, y2), (0, 150, 0), 3)
43
+ img_out = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
44
+
45
+ return img_out
46
+
47
+ with gr.Blocks() as demo:
48
+ gr.Markdown("# Multi-YOLO Demo (BBox only)")
49
+ gr.Markdown("## πŸ”Ž **Select a model first**\nChoose one below, then upload an image.")
50
+
51
+ with gr.Row():
52
+ img_in = gr.Image(type="pil", label="Upload an image")
53
+
54
+ with gr.Column():
55
+ # ここは Radio にする
56
+ model_dd = gr.Radio(
57
+ choices=list(MODEL_REPOS.keys()),
58
+ value=list(MODEL_REPOS.keys())[0],
59
+ label="Select Model",
60
+ interactive=True
61
+ )
62
+
63
+ conf = gr.Slider(0.05, 1.0, value=0.25, step=0.01, label="Confidence (default 0.25)")
64
+ iou = gr.Slider(0.1, 0.95, value=0.70, step=0.01, label="NMS IoU (default 0.70)")
65
+ draw_labels = gr.Checkbox(value=True, label="Draw labels text (off = boxes only)")
66
+ run_btn = gr.Button("Run")
67
+
68
+ img_out = gr.Image(type="pil", label="Detections (boxes only)")
69
+
70
+ run_btn.click(fn=infer, inputs=[img_in, model_dd, conf, iou, draw_labels], outputs=[img_out])
71
+
72
+ if __name__ == "__main__":
73
+ demo.launch()