xiang-wuu
commited on
Commit
•
da9c9dd
1
Parent(s):
60fa02e
inference and integration using gradio app
Browse files
app.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
|
4 |
+
os.system("wget https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5s.pt")
|
5 |
+
os.system("wget https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5s6.pt")
|
6 |
+
|
7 |
+
from models.experimental import attempt_load
|
8 |
+
from utils.augmentations import letterbox
|
9 |
+
from utils.plots import Annotator
|
10 |
+
from utils.general import non_max_suppression, scale_coords
|
11 |
+
from utils.torch_utils import *
|
12 |
+
import sys
|
13 |
+
import numpy as np
|
14 |
+
import random
|
15 |
+
|
16 |
+
|
17 |
+
def detect(img, weights):
|
18 |
+
gpu_id="cuda:0"
|
19 |
+
device = select_device(device=gpu_id)
|
20 |
+
model = attempt_load(weights+'.pt', device=device)
|
21 |
+
torch.no_grad()
|
22 |
+
model.to(device).eval()
|
23 |
+
half = False # half precision only supported on CUDA
|
24 |
+
if half:
|
25 |
+
model.half()
|
26 |
+
|
27 |
+
img_size = 640
|
28 |
+
|
29 |
+
# Get names and colors
|
30 |
+
names = model.names if hasattr(model, 'names') else model.modules.names
|
31 |
+
colors = [[random.randint(0, 255) for _ in range(3)]
|
32 |
+
for _ in range(len(names))]
|
33 |
+
if img is None:
|
34 |
+
sys.exit(0)
|
35 |
+
|
36 |
+
# Run inference
|
37 |
+
t0 = time_sync()
|
38 |
+
|
39 |
+
im0 = img.copy()
|
40 |
+
img = letterbox(img, img_size, stride=int(model.stride.max()), auto=False and True)[0]
|
41 |
+
img = np.stack(img, 0)
|
42 |
+
|
43 |
+
img = img.transpose((2, 0, 1))[::-1] # BGR to RGB, to 3x416x416
|
44 |
+
|
45 |
+
img = np.ascontiguousarray(img)
|
46 |
+
|
47 |
+
img = torch.from_numpy(img).to(device)
|
48 |
+
|
49 |
+
if half:
|
50 |
+
img = img.half()
|
51 |
+
else:
|
52 |
+
img = img.float() # if half else img.float() # uint8 to fp16/32
|
53 |
+
img /= 255.0 # 0 - 255 to 0.0 - 1.0
|
54 |
+
if len(img.shape) == 3:
|
55 |
+
img = img[None] # expand for batch dim
|
56 |
+
# Inference
|
57 |
+
t1 = time_sync()
|
58 |
+
pred = model(img, augment=False, profile=False)[0]
|
59 |
+
|
60 |
+
# to float
|
61 |
+
if half:
|
62 |
+
pred = pred.float()
|
63 |
+
|
64 |
+
# Apply NMS
|
65 |
+
pred = non_max_suppression(
|
66 |
+
pred, 0.1, 0.5, classes=None, agnostic=False)
|
67 |
+
t2 = time_sync()
|
68 |
+
annotator = Annotator(im0, line_width=3, example=str(names))
|
69 |
+
# Process detections
|
70 |
+
for i, det in enumerate(pred): # detections per image
|
71 |
+
s = ''
|
72 |
+
s += '%gx%g ' % img.shape[2:] # print string
|
73 |
+
if det is not None and len(det):
|
74 |
+
# Rescale boxes from img_size to im0 size
|
75 |
+
det[:, :4] = scale_coords(
|
76 |
+
img.shape[2:], det[:, :4], im0.shape).round()
|
77 |
+
|
78 |
+
# Print results
|
79 |
+
for c in det[:, -1].unique():
|
80 |
+
n = (det[:, -1] == c).sum() # detections per class
|
81 |
+
s += '%g %ss, ' % (n, names[int(c)]) # add to string
|
82 |
+
|
83 |
+
# show results
|
84 |
+
for *xyxy, conf, cls in det:
|
85 |
+
label = '%s %.2f' % (names[int(cls)], conf)
|
86 |
+
annotator.box_label(xyxy, label, color=colors[int(cls)])
|
87 |
+
im0 = annotator.result()
|
88 |
+
# Print time (inference + NMS)
|
89 |
+
infer_time = t2 - t1
|
90 |
+
|
91 |
+
print('%sDone. %s' %
|
92 |
+
(s, infer_time))
|
93 |
+
|
94 |
+
print('Done. (%.3fs)' % (time.time() - t0))
|
95 |
+
|
96 |
+
return im0
|
97 |
+
|
98 |
+
|
99 |
+
if __name__ == '__main__':
|
100 |
+
gr.Interface(detect,[gr.Image(type="numpy"),gr.Dropdown(choices=["yolov5s","yolov5s6"])],
|
101 |
+
gr.Image(type="numpy"),title="Yolov5",examples=[["data/images/bus.jpg", "yolov5s"]],
|
102 |
+
description="Gradio based demo for <a href='https://github.com/ultralytics/yolov5' style='text-decoration: underline' target='_blank'>ultralytics/yolov5</a>, new state-of-the-art for real-time object detection").launch()
|