Fredrickkk committed on
Commit f7e770a
1 Parent(s): 3699206

Upload yolov5_plate_onnx_infer.py

Files changed (1)
  1. yolov5_plate_onnx_infer.py +282 -0
yolov5_plate_onnx_infer.py ADDED
@@ -0,0 +1,282 @@
+ import onnxruntime
+ import numpy as np
+ import cv2
+ import copy
+ import os
+ import argparse
+ from PIL import Image, ImageDraw, ImageFont
+ import time
+
+ # Plate background colours predicted by the colour head of the recognition model
+ plate_color_list = ['black', 'blue', 'green', 'white', 'yellow']
+
+ # Character table of the recognition model: index 0 ("#") is treated as the blank by decodePlate,
+ # followed by the (romanised) Chinese province abbreviations and the alphanumeric characters.
+ plateName = [
+     "#", "jing", 'hu', 'jin', 'yu', 'ji', 'jin', 'meng', 'liao', 'ji', 'hei', 'su', 'zhe', 'wan', 'min', 'gan', 'lu', 'yu', 'e', 'xiang',
+     'yue', 'gui', 'qiong', 'chuan', 'gui', 'yun', 'zang', 'shan', 'gan', 'qing', 'ning', 'xin', 'xue', 'jing', 'gang', 'ao', 'gua', 'shi', 'ling',
+     'min', 'hang', 'wei', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'P', 'Q',
+     'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
+ ]
+ print(len(plateName))
+
+ # Landmark colours used by draw_result (one per plate corner)
+ clors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
+
+ # Normalisation constants for the recognition model input
+ mean_value, std_value = (0.588, 0.193)
+
+ def decodePlate(preds):  # greedy CTC-style decode: drop blanks (index 0) and repeated indices
+     pre = 0
+     newPreds = []
+     for i in range(len(preds)):
+         if preds[i] != 0 and preds[i] != pre:
+             newPreds.append(preds[i])
+         pre = preds[i]
+     plate = ""
+     for i in newPreds:
+         plate += plateName[int(i)]
+     return plate
+     # return newPreds
+
+ def rec_pre_precessing(img, size=(48, 168)):  # preprocessing for the recognition model
+     img = cv2.resize(img, (size[1], size[0]))  # resize to width 168, height 48
+     img = img.astype(np.float32)
+     img = (img / 255 - mean_value) / std_value  # normalise
+     img = img.transpose(2, 0, 1)  # h,w,c -> c,h,w
+     img = img.reshape(1, *img.shape)  # c,h,w -> 1,c,h,w (add batch dimension)
+     return img
+
+ def get_plate_result(img, session_rec):  # run the recognition model on a cropped plate image
+     img = rec_pre_precessing(img)
+     y_onnx_plate, y_onnx_color = session_rec.run(
+         [session_rec.get_outputs()[0].name, session_rec.get_outputs()[1].name],
+         {session_rec.get_inputs()[0].name: img})
+     index = np.argmax(y_onnx_plate, axis=-1)
+     index_color = np.argmax(y_onnx_color)
+     plate_color = plate_color_list[index_color]
+     # print(y_onnx[0])
+     plate_no = decodePlate(index[0])
+     return plate_no, plate_color
+
+
+ def allFilePath(rootPath, allFIleList):  # recursively collect every file path under rootPath
+     fileList = os.listdir(rootPath)
+     for temp in fileList:
+         if os.path.isfile(os.path.join(rootPath, temp)):
+             allFIleList.append(os.path.join(rootPath, temp))
+         else:
+             allFilePath(os.path.join(rootPath, temp), allFIleList)
+
+ def get_split_merge(img):  # split a double-row plate and stitch the two rows side by side
+     h, w, c = img.shape
+     img_upper = img[0:int(5 / 12 * h), :]
+     img_lower = img[int(1 / 3 * h):, :]
+     img_upper = cv2.resize(img_upper, (img_lower.shape[1], img_lower.shape[0]))
+     new_img = np.hstack((img_upper, img_lower))
+     return new_img
+
+
+ def order_points(pts):  # order 4 corner points as top-left, top-right, bottom-right, bottom-left
+     rect = np.zeros((4, 2), dtype="float32")
+     s = pts.sum(axis=1)
+     rect[0] = pts[np.argmin(s)]  # top-left has the smallest x+y
+     rect[2] = pts[np.argmax(s)]  # bottom-right has the largest x+y
+     diff = np.diff(pts, axis=1)
+     rect[1] = pts[np.argmin(diff)]  # top-right has the smallest y-x
+     rect[3] = pts[np.argmax(diff)]  # bottom-left has the largest y-x
+     return rect
+
+
+ def four_point_transform(image, pts):  # perspective-warp the quadrilateral pts to an upright rectangle
+     rect = order_points(pts)
+     (tl, tr, br, bl) = rect
+     widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
+     widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
+     maxWidth = max(int(widthA), int(widthB))
+     heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
+     heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
+     maxHeight = max(int(heightA), int(heightB))
+     dst = np.array([
+         [0, 0],
+         [maxWidth - 1, 0],
+         [maxWidth - 1, maxHeight - 1],
+         [0, maxHeight - 1]], dtype="float32")
+     M = cv2.getPerspectiveTransform(rect, dst)
+     warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
+     # return the warped image
+     return warped
+
+ def my_letter_box(img, size=(640, 640)):  # letterbox resize: keep the aspect ratio and pad to size with grey
+     h, w, c = img.shape
+     r = min(size[0] / h, size[1] / w)
+     new_h, new_w = int(h * r), int(w * r)
+     top = int((size[0] - new_h) / 2)
+     left = int((size[1] - new_w) / 2)
+     bottom = size[0] - new_h - top
+     right = size[1] - new_w - left
+     img_resize = cv2.resize(img, (new_w, new_h))
+     img = cv2.copyMakeBorder(img_resize, top, bottom, left, right, borderType=cv2.BORDER_CONSTANT, value=(114, 114, 114))
+     return img, r, left, top
+
+ def xywh2xyxy(boxes):  # convert boxes from center x,y,w,h to x1,y1,x2,y2
+     xywh = copy.deepcopy(boxes)
+     xywh[:, 0] = boxes[:, 0] - boxes[:, 2] / 2
+     xywh[:, 1] = boxes[:, 1] - boxes[:, 3] / 2
+     xywh[:, 2] = boxes[:, 0] + boxes[:, 2] / 2
+     xywh[:, 3] = boxes[:, 1] + boxes[:, 3] / 2
+     return xywh
+
+ def my_nms(boxes, iou_thresh):  # non-maximum suppression on rows of [x1,y1,x2,y2,score,...]
+     index = np.argsort(boxes[:, 4])[::-1]  # sort by score, highest first
+     keep = []
+     while index.size > 0:
+         i = index[0]
+         keep.append(i)
+         # intersection of the current box with all remaining boxes
+         x1 = np.maximum(boxes[i, 0], boxes[index[1:], 0])
+         y1 = np.maximum(boxes[i, 1], boxes[index[1:], 1])
+         x2 = np.minimum(boxes[i, 2], boxes[index[1:], 2])
+         y2 = np.minimum(boxes[i, 3], boxes[index[1:], 3])
+         w = np.maximum(0, x2 - x1)
+         h = np.maximum(0, y2 - y1)
+         inter_area = w * h
+         union_area = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1]) + (boxes[index[1:], 2] - boxes[index[1:], 0]) * (boxes[index[1:], 3] - boxes[index[1:], 1])
+         iou = inter_area / (union_area - inter_area)  # union_area is the sum of the two areas, so subtract the intersection
+         idx = np.where(iou <= iou_thresh)[0]
+         index = index[idx + 1]  # +1 because iou was computed against index[1:]
+     return keep
+
+ def restore_box(boxes, r, left, top):  # map boxes and landmarks from letterbox space back to the original image
+     boxes[:, [0, 2, 5, 7, 9, 11]] -= left
+     boxes[:, [1, 3, 6, 8, 10, 12]] -= top
+     boxes[:, [0, 2, 5, 7, 9, 11]] /= r
+     boxes[:, [1, 3, 6, 8, 10, 12]] /= r
+     return boxes
+
+ def detect_pre_precessing(img, img_size):  # preprocessing for the detection model
+     img, r, left, top = my_letter_box(img, img_size)
+     # cv2.imwrite("1.jpg",img)
+     img = img[:, :, ::-1].transpose(2, 0, 1).copy().astype(np.float32)  # BGR -> RGB, h,w,c -> c,h,w
+     img = img / 255
+     img = img.reshape(1, *img.shape)  # add batch dimension
+     return img, r, left, top
+
+ def post_precessing(dets, r, left, top, conf_thresh=0.3, iou_thresh=0.5):  # postprocessing for the detection output
+     choice = dets[:, :, 4] > conf_thresh  # keep predictions above the objectness threshold
+     dets = dets[choice]
+     dets[:, 13:15] *= dets[:, 4:5]  # class scores *= objectness
+     box = dets[:, :4]
+     boxes = xywh2xyxy(box)
+     score = np.max(dets[:, 13:15], axis=-1, keepdims=True)
+     index = np.argmax(dets[:, 13:15], axis=-1).reshape(-1, 1)  # plate class (1 = double-row plate, see rec_plate)
+     output = np.concatenate((boxes, score, dets[:, 5:13], index), axis=1)  # box, score, 4 landmarks, class
+     reserve_ = my_nms(output, iou_thresh)
+     output = output[reserve_]
+     output = restore_box(output, r, left, top)
+     return output
+
+ def rec_plate(outputs, img0, session_rec):  # crop, rectify and recognise each detected plate
+     dict_list = []
+     for output in outputs:
+         result_dict = {}
+         rect = output[:4].tolist()
+         land_marks = output[5:13].reshape(4, 2)
+         roi_img = four_point_transform(img0, land_marks)
+         label = int(output[-1])
+         score = output[4]
+         if label == 1:  # class 1 is a double-row plate: merge its two rows into one
+             roi_img = get_split_merge(roi_img)
+         plate_no, plate_color = get_plate_result(roi_img, session_rec)
+         result_dict['rect'] = rect
+         result_dict['landmarks'] = land_marks.tolist()
+         result_dict['plate_no'] = plate_no
+         result_dict['roi_height'] = roi_img.shape[0]
+         result_dict['plate_color'] = plate_color
+         dict_list.append(result_dict)
+     return dict_list
+
+ def cv2ImgAddText(img, text, left, top, textColor=(0, 255, 0), textSize=20):  # draw text on an OpenCV image via PIL
+     if isinstance(img, np.ndarray):  # convert the OpenCV BGR image to a PIL RGB image
+         img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+     draw = ImageDraw.Draw(img)
+     fontText = ImageFont.truetype(
+         "fonts/platech.ttf", textSize, encoding="utf-8")
+     draw.text((left, top), text, textColor, font=fontText)
+     return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
+
+ def draw_result(orgimg, dict_list):  # draw boxes, landmarks and plate text on the original image
+     result_str = ""
+     for result in dict_list:
+         rect_area = result['rect']
+         x, y, w, h = rect_area[0], rect_area[1], rect_area[2] - rect_area[0], rect_area[3] - rect_area[1]
+         padding_w = 0.05 * w
+         padding_h = 0.11 * h
+         # enlarge the box slightly and clamp it to the image bounds
+         rect_area[0] = max(0, int(x - padding_w))
+         rect_area[1] = max(0, int(y - padding_h))
+         rect_area[2] = min(orgimg.shape[1], int(rect_area[2] + padding_w))
+         rect_area[3] = min(orgimg.shape[0], int(rect_area[3] + padding_h))
+         height_area = result['roi_height']
+         landmarks = result['landmarks']
+         result = result['plate_no']
+         result_str += result + " "
+         for i in range(4):  # draw the four plate corners
+             cv2.circle(orgimg, (int(landmarks[i][0]), int(landmarks[i][1])), 5, clors[i], -1)
+         cv2.rectangle(orgimg, (rect_area[0], rect_area[1]), (rect_area[2], rect_area[3]), (255, 255, 0), 2)  # draw the plate box
+         if len(result) >= 1:
+             orgimg = cv2ImgAddText(orgimg, result, rect_area[0] - height_area, rect_area[1] - height_area - 10, (0, 255, 0), height_area)
+     print(result_str)
+     return orgimg
+
+ def init_car_plate_detect_model(model_path, providers):  # load the plate detection ONNX model
+     session_detect = onnxruntime.InferenceSession(model_path, providers=providers)
+     return session_detect
+
+ def init_car_plate_rec_model(model_path, providers):  # load the plate recognition ONNX model
+     session_rec = onnxruntime.InferenceSession(model_path, providers=providers)
+     return session_rec
+
+ def detect_plate(img, session_detect, img_size):  # run detection on one image and return post-processed boxes
+     img, r, left, top = detect_pre_precessing(img, (img_size, img_size))  # letterbox + normalise
+     # print(img.shape)
+     y_onnx = session_detect.run([session_detect.get_outputs()[0].name], {session_detect.get_inputs()[0].name: img})[0]
+     outputs = post_precessing(y_onnx, r, left, top)  # confidence filter + NMS + restore to original coordinates
+     return outputs
+
+
+ # Example batch-inference entry point (commented out):
+ # if __name__ == "__main__":
+ #     begin = time.time()
+ #     parser = argparse.ArgumentParser()
+ #     parser.add_argument('--detect_model', type=str, default=r'E:/study/Object_detect/Chinese_license_plate_detection_recognition-main/weights/best_exp3.onnx', help='model.pt path(s)')  # detection model
+ #     parser.add_argument('--rec_model', type=str, default='weights/best_rec.onnx', help='model.pt path(s)')  # recognition model
+ #     parser.add_argument('--image_path', type=str, default=r'E:\study\onnx_runtime\vehicle_type_brand\WJ', help='source')
+ #     parser.add_argument('--img_size', type=int, default=640, help='inference size (pixels)')
+ #     parser.add_argument('--output', type=str, default='result', help='source')
+ #     opt = parser.parse_args()
+ #     file_list = []
+ #     allFilePath(opt.image_path, file_list)
+ #     providers = ['CUDAExecutionProvider']
+ #     clors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
+ #     img_size = (opt.img_size, opt.img_size)
+ #     session_detect = init_car_plate_detect_model(opt.detect_model, providers=providers)
+ #     session_rec = init_car_plate_rec_model(opt.rec_model, providers=providers)
+ #     if not os.path.exists(opt.output):
+ #         os.mkdir(opt.output)
+ #     save_path = opt.output
+ #     count = 0
+ #     for pic_ in file_list:
+ #         count += 1
+ #         print(count, pic_, end=" ")
+ #         img = cv2.imread(pic_)
+ #         img0 = copy.deepcopy(img)
+ #         img, r, left, top = detect_pre_precessing(img, img_size)  # preprocessing
+ #         # print(img.shape)
+ #         # img = np.concatenate((img, img))
+ #         # print(img)
+ #         y_onnx = session_detect.run([session_detect.get_outputs()[0].name], {session_detect.get_inputs()[0].name: img})[0]
+ #         outputs = post_precessing(y_onnx, r, left, top)  # postprocessing
+ #         result_list = rec_plate(outputs, img0, session_rec)
+ #         ori_img = draw_result(img0, result_list)
+ #         img_name = os.path.basename(pic_)
+ #         save_img_path = os.path.join(save_path, img_name)
+ #         cv2.imwrite(save_img_path, ori_img)
+ #     print(f"cost time {time.time()-begin} s")
+
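+
+ # Minimal single-image usage sketch of the module-level API above (init -> detect_plate -> rec_plate ->
+ # draw_result). This is an illustration only; the model and image paths below are placeholders and must
+ # be replaced with the actual ONNX weight files and input image.
+ # if __name__ == "__main__":
+ #     providers = ['CPUExecutionProvider']  # or ['CUDAExecutionProvider'] with a GPU build of onnxruntime
+ #     session_detect = init_car_plate_detect_model('weights/plate_detect.onnx', providers)     # placeholder path
+ #     session_rec = init_car_plate_rec_model('weights/plate_rec_color.onnx', providers)        # placeholder path
+ #     img = cv2.imread('test.jpg')                                                             # placeholder image
+ #     img0 = copy.deepcopy(img)
+ #     outputs = detect_plate(img, session_detect, 640)      # detection + NMS in original image coordinates
+ #     result_list = rec_plate(outputs, img0, session_rec)   # crop, rectify and recognise each plate
+ #     print(result_list)
+ #     cv2.imwrite('result.jpg', draw_result(img0, result_list))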