zhengrongzhang and wangfangyuan committed
Commit 1620f89
1 Parent(s): e0810ef

Create test_infer_onnx.py (#6)


- Create test_infer_onnx.py (290f357c21d3fe2a506fcccd008d4df6f47d7b05)


Co-authored-by: fangyuan wang <wangfangyuan@users.noreply.huggingface.co>

Files changed (1)
  1. test_infer_onnx.py +154 -0
test_infer_onnx.py ADDED
@@ -0,0 +1,154 @@
+ import torch
+ import torch.nn as nn
+ import onnxruntime
+ import numpy as np
+ import argparse
+ from utils import (
+     LoadImages,
+     non_max_suppression,
+     plot_images,
+     output_to_target,
+ )
+ import sys
+ import pathlib
+ CURRENT_DIR = pathlib.Path(__file__).parent
+ sys.path.append(str(CURRENT_DIR))
+ from optimum.amd.ryzenai import RyzenAIModelForObjectDetection
+
+ def preprocess(img):
+     img = torch.from_numpy(img)
+     img = img.float()  # uint8 to fp16/32
+     img /= 255  # 0 - 255 to 0.0 - 1.0
+     return img
+
+
+ class DFL(nn.Module):
+     # Integral module of Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
+     def __init__(self, c1=16):
+         super().__init__()
+         self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
+         x = torch.arange(c1, dtype=torch.float)
+         self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
+         self.c1 = c1
+
+     def forward(self, x):
+         b, c, a = x.shape  # batch, channels, anchors
+         return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(
+             b, 4, a
+         )
+
+
+ def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
+     """Transform distance(ltrb) to box(xywh or xyxy)."""
+     lt, rb = torch.split(distance, 2, dim)
+     x1y1 = anchor_points - lt
+     x2y2 = anchor_points + rb
+     if xywh:
+         c_xy = (x1y1 + x2y2) / 2
+         wh = x2y2 - x1y1
+         return torch.cat((c_xy, wh), dim)  # xywh bbox
+     return torch.cat((x1y1, x2y2), dim)  # xyxy bbox
+
+
+ def post_process(x):
+     dfl = DFL(16)
+     anchors = torch.tensor(
+         np.load(
+             "./anchors.npy",
+             allow_pickle=True,
+         )
+     )
+     strides = torch.tensor(
+         np.load(
+             "./strides.npy",
+             allow_pickle=True,
+         )
+     )
+     box, cls = torch.cat([xi.view(x[0].shape[0], 144, -1) for xi in x], 2).split(
+         (16 * 4, 80), 1
+     )
+     dbox = dist2bbox(dfl(box), anchors.unsqueeze(0), xywh=True, dim=1) * strides
+     y = torch.cat((dbox, cls.sigmoid()), 1)
+     return y, x
+
+
+ def make_parser():
+     parser = argparse.ArgumentParser("onnxruntime inference sample")
+     parser.add_argument(
+         "-m",
+         "--onnx_model",
+         type=str,
+         default="./yolov8m.onnx",
+         help="input your onnx model.",
+     )
+     parser.add_argument(
+         "-i",
+         "--image_path",
+         type=str,
+         default='./demo.jpg',
+         help="path to your input image.",
+     )
+     parser.add_argument(
+         "-o",
+         "--output_path",
+         type=str,
+         default='./demo_infer.jpg',
+         help="path to your output image.",
+     )
+     parser.add_argument(
+         "--ipu", action='store_true', help='flag for ryzen ai'
+     )
+     parser.add_argument(
+         "--provider_config", default='', type=str, help='provider config for ryzen ai'
+     )
+     return parser
+
+ classnames = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
+               'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+               'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+               'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
+               'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+               'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+               'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+               'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
+               'hair drier', 'toothbrush']
+ names = {k: classnames[k] for k in range(80)}
+ imgsz = [640, 640]
+
+
+ if __name__ == '__main__':
+     args = make_parser().parse_args()
+     source = args.image_path
+     dataset = LoadImages(
+         source, imgsz=imgsz, stride=32, auto=False, transforms=None, vid_stride=1
+     )
+     onnx_weight = args.onnx_model
+     if args.ipu:
+         onnx_model = RyzenAIModelForObjectDetection.from_pretrained(".\\", vaip_config=args.provider_config)
+         # providers = ["VitisAIExecutionProvider"]
+         # provider_options = [{"config_file": args.provider_config}]
+         # onnx_model = onnxruntime.InferenceSession(onnx_weight, providers=providers, provider_options=provider_options)
+     else:
+         onnx_model = onnxruntime.InferenceSession(onnx_weight)
+     for batch in dataset:
+         path, im, im0s, vid_cap, s = batch
+         im = preprocess(im)
+         if len(im.shape) == 3:
+             im = im[None]
+         # outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: im.cpu().numpy()})
+         # outputs = [torch.tensor(item) for item in outputs]
+         # outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: im.permute(0, 2, 3, 1).cpu().numpy()})
+         # outputs = [torch.tensor(item).permute(0, 3, 1, 2) for item in outputs]
+         outputs = onnx_model(im.permute(0, 2, 3, 1))
+         outputs = [outputs[0].permute(0, 3, 1, 2), outputs[1].permute(0, 3, 1, 2), outputs[2].permute(0, 3, 1, 2)]
+         preds = post_process(outputs)
+         preds = non_max_suppression(
+             preds, 0.25, 0.7, agnostic=False, max_det=300, classes=None
+         )
+         plot_images(
+             im,
+             *output_to_target(preds, max_det=15),
+             source,
+             fname=args.output_path,
+             names=names,
+         )