lohitkavuru14 committed on
Commit
6f0cd6e
1 Parent(s): efeb88d

Create app.py

Files changed (1): app.py +0 -368
app.py CHANGED
import fileinput
import itertools
import math
import os
import re
from copy import deepcopy
from operator import itemgetter
from pathlib import Path
from typing import Union

import cv2  # type: ignore
import gradio as gr  # type: ignore
import numpy as np
import torch
from deep_sort_realtime.deepsort_tracker import DeepSort  # type: ignore
from paddleocr import PaddleOCR  # type: ignore

if not os.path.isfile("weights.pt"):
    weights_url = "https://archive.org/download/anpr_weights/weights.pt"
    os.system(f"wget {weights_url}")

if not os.path.isdir("examples"):
    examples_url = "https://archive.org/download/anpr_examples_202208/examples.tar.gz"
    os.system(f"wget {examples_url}")
    os.system("tar -xvf examples.tar.gz")
    os.system("rm -rf examples.tar.gz")
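# Note: the downloads above assume wget is available on the host. A
# dependency-free sketch using only the standard library would be:
#
#   from urllib.request import urlretrieve
#   urlretrieve(weights_url, "weights.pt")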


def prepend_text(filename: Union[str, Path], text: str):
    # Insert text as a new first line of the file, rewriting it in place.
    with fileinput.input(filename, inplace=True) as file:
        for line in file:
            if file.isfirstline():
                print(text)
            print(line, end="")

if not os.path.isdir("yolov7"):
    yolov7_repo_url = "https://github.com/WongKinYiu/yolov7"
    os.system(f"git clone {yolov7_repo_url}")
    # Fix import errors
    for file in [
        "yolov7/models/common.py",
        "yolov7/models/experimental.py",
        "yolov7/models/yolo.py",
        "yolov7/utils/datasets.py",
    ]:
        prepend_text(file, "import sys\nsys.path.insert(0, './yolov7')")

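# Note: the yolov7 repository imports its own modules with top-level paths such
# as `from models.common import ...`, which fail when the repo lives in a
# subdirectory. Prepending `sys.path.insert(0, './yolov7')` to the files above
# makes those imports resolvable without restructuring the checkout.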
from yolov7.models.experimental import attempt_load  # type: ignore
from yolov7.utils.datasets import letterbox  # type: ignore
from yolov7.utils.general import check_img_size  # type: ignore
from yolov7.utils.general import non_max_suppression  # type: ignore
from yolov7.utils.general import scale_coords  # type: ignore
from yolov7.utils.plots import plot_one_box  # type: ignore
from yolov7.utils.torch_utils import TracedModel, select_device  # type: ignore

weights = "weights.pt"
device_id = "cpu"
image_size = 640
trace = True

# Initialize
device = select_device(device_id)
half = device.type != "cpu"  # half precision only supported on CUDA

# Load model
model = attempt_load(weights, map_location=device)  # load FP32 model
stride = int(model.stride.max())  # model stride
imgsz = check_img_size(image_size, s=stride)  # check img_size

if trace:
    model = TracedModel(model, device, image_size)

if half:
    model.half()  # to FP16

if device.type != "cpu":
    model(
        torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))
    )  # run once

model.eval()

# Load OCR
paddle = PaddleOCR(lang="en")
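# Note: on first use PaddleOCR downloads its model weights to a local cache,
# so the first OCR call after a cold start can be noticeably slower.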


def detect_plate(source_image):
    # Padded resize
    img_size = 640
    stride = 32
    img = letterbox(source_image, img_size, stride=stride)[0]

    # Convert
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
    img = np.ascontiguousarray(img)
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)

    with torch.no_grad():
        # Inference
        pred = model(img, augment=True)[0]

    # Apply NMS
    pred = non_max_suppression(pred, 0.25, 0.45, classes=0, agnostic=True)

    plate_detections = []
    det_confidences = []

    # Process detections
    for i, det in enumerate(pred):  # detections per image
        if len(det):
            # Rescale boxes from img_size to source image size
            det[:, :4] = scale_coords(
                img.shape[2:], det[:, :4], source_image.shape
            ).round()

            # Return results
            for *xyxy, conf, cls in reversed(det):
                coords = [
                    int(position)
                    for position in (torch.tensor(xyxy).view(1, 4)).tolist()[0]
                ]
                plate_detections.append(coords)
                det_confidences.append(conf.item())

    return plate_detections, det_confidences
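# detect_plate returns boxes in Pascal VOC form, i.e. [x1, y1, x2, y2] pixel
# coordinates in the source image, with a parallel list of detection
# confidences, e.g. ([[120, 340, 410, 420]], [0.93]) for a single plate
# (illustrative values only).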


def unsharp_mask(image, kernel_size=(5, 5), sigma=1.0, amount=2.0, threshold=0):
    blurred = cv2.GaussianBlur(image, kernel_size, sigma)
    sharpened = float(amount + 1) * image - float(amount) * blurred
    sharpened = np.maximum(sharpened, np.zeros(sharpened.shape))
    sharpened = np.minimum(sharpened, 255 * np.ones(sharpened.shape))
    sharpened = sharpened.round().astype(np.uint8)
    if threshold > 0:
        low_contrast_mask = np.absolute(image - blurred) < threshold
        np.copyto(sharpened, image, where=low_contrast_mask)
    return sharpened
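# A worked example of the unsharp-mask arithmetic above: with the default
# amount=2.0 the sharpened value is 3 * image - 2 * blurred, so a pixel of 100
# whose Gaussian-blurred value is 90 becomes 3 * 100 - 2 * 90 = 120; the result
# is then clipped to [0, 255] and rounded back to uint8.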


def crop(image, coord):
    cropped_image = image[int(coord[1]) : int(coord[3]), int(coord[0]) : int(coord[2])]
    return cropped_image


def ocr_plate(plate_region):
    # Image pre-processing for more accurate OCR
    rescaled = cv2.resize(
        plate_region, None, fx=1.2, fy=1.2, interpolation=cv2.INTER_CUBIC
    )
    grayscale = cv2.cvtColor(rescaled, cv2.COLOR_BGR2GRAY)
    kernel = np.ones((1, 1), np.uint8)
    dilated = cv2.dilate(grayscale, kernel, iterations=1)
    eroded = cv2.erode(dilated, kernel, iterations=1)
    sharpened = unsharp_mask(eroded)

    # OCR the preprocessed image
    results = paddle.ocr(sharpened, det=False, cls=False)
    flattened = list(itertools.chain.from_iterable(results))
    plate_text, ocr_confidence = max(flattened, key=itemgetter(1), default=("", 0))

    # Filter out anything but uppercase letters, digits, hyphens and whitespace.
    plate_text = re.sub(r"[^-A-Z0-9 ]", r"", plate_text).strip()

    # A failed recognition can yield a NaN confidence; treat it as zero.
    if isinstance(ocr_confidence, float) and math.isnan(ocr_confidence):
        ocr_confidence = 0

    return plate_text, ocr_confidence
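# With det=False and cls=False, paddle.ocr() skips text detection and angle
# classification and returns (text, confidence) pairs for the given image; the
# flatten-and-max above keeps the highest-confidence reading, falling back to
# ("", 0) when nothing is recognized.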


def get_plates_from_image(input_image):
    if input_image is None:
        return None
    plate_detections, det_confidences = detect_plate(input_image)
    plate_texts = []
    ocr_confidences = []
    detected_image = deepcopy(input_image)
    for coords in plate_detections:
        plate_region = crop(input_image, coords)
        plate_text, ocr_confidence = ocr_plate(plate_region)
        if ocr_confidence == 0:  # If OCR confidence is 0, skip this detection
            continue
        plate_texts.append(plate_text)
        ocr_confidences.append(ocr_confidence)
        plot_one_box(
            coords,
            detected_image,
            label=plate_text,
            color=[0, 150, 255],
            line_thickness=2,
        )
    return detected_image


def pascal_voc_to_coco(x1y1x2y2):
    x1, y1, x2, y2 = x1y1x2y2
    return [x1, y1, x2 - x1, y2 - y1]
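# For illustration: pascal_voc_to_coco([100, 50, 300, 200]) returns
# [100, 50, 200, 150], i.e. the same top-left corner plus width and height,
# the [left, top, width, height] layout the DeepSort tracker below expects.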


def get_best_ocr(preds, rec_conf, ocr_res, track_id):
    for info in preds:
        # Check if this is the current track id
        if info["track_id"] == track_id:
            # Keep whichever OCR result has the higher confidence
            if info["ocr_conf"] < rec_conf:
                info["ocr_conf"] = rec_conf
                info["ocr_txt"] = ocr_res
            else:
                rec_conf = info["ocr_conf"]
                ocr_res = info["ocr_txt"]
            break
    return preds, rec_conf, ocr_res
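# get_best_ocr keeps a single best reading per track: a new, more confident
# result replaces the stored one, otherwise the stored best is returned, so a
# momentary misread never overwrites a good plate string.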


def get_plates_from_video(source):
    if source is None:
        return None

    # Create a VideoCapture object
    video = cv2.VideoCapture(source)

    # The default frame resolutions are system dependent;
    # we convert the resolutions from float to integer.
    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = video.get(cv2.CAP_PROP_FPS)

    # Define the codec and create a VideoWriter object.
    temp = f"{Path(source).stem}_temp{Path(source).suffix}"
    export = cv2.VideoWriter(
        temp, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )

    # Initializing the tracker
    tracker = DeepSort(embedder_gpu=False)

    # Initializing some helper variables.
    preds = []
    total_obj = 0
    while True:
        ret, frame = video.read()
        if ret:
            # Run the ANPR algorithm
            bboxes, scores = detect_plate(frame)
            # Convert Pascal VOC detections to COCO
            bboxes = list(map(pascal_voc_to_coco, bboxes))

            if len(bboxes) > 0:
                # Storing all the required info in a list.
                detections = [
                    (bbox, score, "number_plate") for bbox, score in zip(bboxes, scores)
                ]

                # Applying the tracker.
                # The tracker flow: Kalman filter -> target association (using the
                # Hungarian algorithm) and an appearance descriptor.
                tracks = tracker.update_tracks(detections, frame=frame)

                # Checking if tracks exist.
                for track in tracks:
                    if not track.is_confirmed() or track.time_since_update > 1:
                        continue

                    # Changing the track bbox to top-left, bottom-right coordinates
                    bbox = [int(position) for position in list(track.to_tlbr())]

                    # Clamp negative coordinates to the frame boundary
                    for i in range(len(bbox)):
                        if bbox[i] < 0:
                            bbox[i] = 0

                    # Cropping the license plate and applying the OCR.
                    plate_region = crop(frame, bbox)
                    plate_text, ocr_confidence = ocr_plate(plate_region)

                    # Storing the OCR output for the corresponding track id.
                    output_frame = {
                        "track_id": track.track_id,
                        "ocr_txt": plate_text,
                        "ocr_conf": ocr_confidence,
                    }

                    # Append the track id if it is new; otherwise keep the
                    # highest-confidence OCR result seen so far for this track.
                    if track.track_id not in set(pred["track_id"] for pred in preds):
                        total_obj += 1
                        preds.append(output_frame)
                    else:
                        preds, ocr_confidence, plate_text = get_best_ocr(
                            preds, ocr_confidence, plate_text, track.track_id
                        )

                    # Plotting the prediction.
                    plot_one_box(
                        bbox,
                        frame,
                        label=f"{track.track_id}. {plate_text}",
                        color=[255, 150, 0],
                        line_thickness=3,
                    )

            # Write the frame into the output file
            export.write(frame)
        else:
            break

    # When everything is done, release the video capture and video writer objects
    video.release()
    export.release()

    # Compressing the output video for smaller size and web compatibility.
    output = f"{Path(source).stem}_detected{Path(source).suffix}"
    os.system(
        f"ffmpeg -y -i {temp} -c:v libx264 -b:v 5000k -minrate 1000k -maxrate 8000k -pass 1 -c:a aac -f mp4 /dev/null && ffmpeg -i {temp} -c:v libx264 -b:v 5000k -minrate 1000k -maxrate 8000k -pass 2 -c:a aac -movflags faststart {output}"
    )
    os.system(f"rm -rf {temp} ffmpeg2pass-0.log ffmpeg2pass-0.log.mbtree")

    return output


with gr.Blocks() as demo:
    gr.Markdown('<h3 align="center">Automatic Number Plate Recognition</h3>')
    gr.Markdown(
        "This AI was trained to detect and recognize number plates on vehicles."
    )
    with gr.Tabs():
        with gr.TabItem("Image"):
            with gr.Row():
                image_input = gr.Image()
                image_output = gr.Image()
            image_input.change(
                get_plates_from_image, inputs=image_input, outputs=image_output
            )
            gr.Examples(
                [
                    ["examples/test_image_1.jpg"],
                    ["examples/test_image_2.jpg"],
                    ["examples/test_image_3.png"],
                    ["examples/test_image_4.jpeg"],
                ],
                [image_input],
                image_output,
                get_plates_from_image,
                cache_examples=True,
            )
        with gr.TabItem("Video"):
            with gr.Row():
                video_input = gr.Video(format="mp4")
                video_output = gr.Video(format="mp4")
            video_input.change(
                get_plates_from_video, inputs=video_input, outputs=video_output
            )
            gr.Examples(
                [["examples/test_video_1.mp4"]],
                [video_input],
                video_output,
                get_plates_from_video,
                cache_examples=True,
            )
    gr.Markdown("[@itsyoboieltr](https://github.com/itsyoboieltr)")

demo.launch()