ml-debi committed
Commit 1d241a3
1 Parent(s): f5e2541
Files changed (3)
  1. app.py +198 -0
  2. packages.txt +1 -0
  3. requirements.txt +99 -0
app.py ADDED
@@ -0,0 +1,198 @@
+ import os
+ import cv2
+ import numpy as np
+ import onnxruntime as ort
+ import pytesseract
+ from PIL import Image
+ import gradio as gr
+ import torchvision
+ from huggingface_hub import hf_hub_download
+
+
+ app_title = "License Plate Object Detection"
+ #model = ["ml-debi/yolov8_license_plate_detection"]
+
+ def build_tesseract_options(psm=7):
+     # tell Tesseract to only OCR alphanumeric characters
+     alphanumeric = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+     options = "-c tessedit_char_whitelist={}".format(alphanumeric)
+     # set the PSM mode
+     options += " --psm {}".format(psm)
+     # return the built options string
+     return options
+
+ # Cropped image processing
+ def auto_canny(image, sigma=0.33):
+     # compute the median of the single channel pixel intensities
+     v = np.median(image)
+
+     # apply automatic Canny edge detection using the computed median
+     lower = int(max(0, (1.0 - sigma) * v))
+     upper = int(min(255, (1.0 + sigma) * v))
+     edged = cv2.Canny(image, lower, upper)
+
+     # return the edged image
+     return edged
+
+
+
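+ # Isolate character-like regions on the cropped plate before OCR: threshold,
+ # edge-detect, then mask out everything that does not look like a character.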
+ def ocr_image_process(img, sigma, block_size, constant):
+     # If the input is a numpy array, convert it to a PIL Image
+     if isinstance(img, np.ndarray):
+         img = Image.fromarray(img)
+
+     # Convert the PIL Image back to a numpy array if necessary
+     if isinstance(img, Image.Image):
+         img = np.array(img)
+
+     gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+     thresh_inv = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, int(block_size), int(constant))  # 41, 1
+     edges = auto_canny(thresh_inv, sigma)
+     ctrs, _ = cv2.findContours(edges.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+     sorted_ctrs = sorted(ctrs, key=lambda ctr: cv2.boundingRect(ctr)[0])
+     img_area = img.shape[0]*img.shape[1]
+     # Create a blank white image
+     mask = np.ones(img.shape, dtype="uint8") * 255
+
+     for i, ctr in enumerate(sorted_ctrs):
+         x, y, w, h = cv2.boundingRect(ctr)
+         roi_area = w*h
+         roi_ratio = roi_area/img_area
+         if (roi_ratio >= 0.015) and (roi_ratio < 0.09):
+             if (h > 1.2*w) and (3*w >= h):
+                 # Draw filled rectangle (mask) on the mask image
+                 cv2.rectangle(mask, (x, y), (x+w, y+h), (0, 0, 0), -1)
+
+     # Bitwise-or input image and mask to get result
+     img = cv2.bitwise_or(img, mask)
+     # Convert the image to grayscale (if it isn't already)
+     img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+     return img
+
+
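+ # Run the ONNX model on a resized copy of the input, keeping the x/y scale
+ # factors so detections can be mapped back to the original resolution.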
+ def get_detections(image_path, size, ort_session):
+     """
+     Function to get detections from the model.
+     """
+     # Check if image_path is a string (indicating a file path)
+     if isinstance(image_path, str):
+         # Check if the image is a PNG
+         if image_path.lower().endswith('.png'):
+             # Open the image file
+             img = Image.open(image_path)
+             # Convert the image to RGB (removes the alpha channel)
+             rgb_img = img.convert('RGB')
+             # Create a new file name by replacing .png with .jpg
+             jpg_image_path = os.path.splitext(image_path)[0] + '.jpg'
+             # Save the RGB image as a JPG
+             rgb_img.save(jpg_image_path)
+             # Update image_path to point to the new JPG image
+             image_path = jpg_image_path
+
+         image = Image.open(image_path)
+     # Check if image_path is a NumPy array
+     elif isinstance(image_path, np.ndarray):
+         image = Image.fromarray(image_path)
+     else:
+         raise ValueError(
+             "image_path must be a file path (str) or a NumPy array.")
+
+     scale_x = image.width / size
+     scale_y = image.height / size
+     resized_image = image.resize((size, size))
+     transform = torchvision.transforms.ToTensor()
+     input_tensor = transform(resized_image).unsqueeze(0)
+     outputs = ort_session.run(None, {'images': input_tensor.numpy()})
+     return image, outputs, scale_x, scale_y
+
+
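+ # Note: this is a simplified form of "NMS" that keeps only the single
+ # highest-confidence candidate box rather than suppressing overlaps by IoU.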
+ def non_maximum_suppression(outputs, min_confidence):
+     """
+     Function to apply non-maximum suppression.
+     """
+     boxes = outputs[0][0]
+     confidences = boxes[4]
+     max_confidence_index = np.argmax(confidences)
+     if confidences[max_confidence_index] > min_confidence:
+         return boxes[:, max_confidence_index]
+     else:
+         return None
+
+
+ def drawings(image, boxes, scale_x, scale_y, sigma, block_size, constant, ocr):
+     """
+     Function to draw bounding boxes and apply OCR.
+     """
+     # Boxes are in YOLO format: centre x, centre y, width, height, confidence
+     x, y, w, h, c = boxes
+     x_min, y_min = (x - w / 2) * scale_x, (y - h / 2) * scale_y
+     x_max, y_max = (x + w / 2) * scale_x, (y + h / 2) * scale_y
+     license_plate_image = image.crop((x_min, y_min, x_max, y_max))
+     processed_cropped_image = ocr_image_process(license_plate_image, sigma, block_size, constant)
+
+     if ocr == "easyocr":
+         import easyocr
+         reader = easyocr.Reader(['en'])
+         result = reader.readtext(processed_cropped_image)
+         try:
+             license_plate_text = str.upper(result[0][1])
+         except IndexError:
+             license_plate_text = "No result found"
+         print(license_plate_text)
+     else:
+         options = build_tesseract_options(7)
+         license_plate_text = pytesseract.image_to_string(
+             processed_cropped_image,
+             config=options)
+         print(license_plate_text)
+     # Calculate the font scale based on image size
+     font_scale = 0.001 * max(image.size)
+
+     image = cv2.rectangle(np.array(image), (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 0, 255), 3)
+     #cv2.putText(image, f'License Plate: {license_plate_text}', (int(x_min), int(y_max)), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), 2)
+     cv2.putText(image, f'Confidence: {c:.2f}', (int(x_min), int(y_min)), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), 1)
+
+     return image, license_plate_image, processed_cropped_image, license_plate_text
+
+
+ def yolo_predictions(image_path, size, sigma, block_size, constant, min_confidence, ort_session, ocr):
+     """
+     Function to get YOLO predictions.
+     """
+     image, outputs, scale_x, scale_y = get_detections(
+         image_path, size, ort_session)
+     boxes = non_maximum_suppression(outputs, min_confidence)
+     if boxes is None:
+         # Nothing above the confidence threshold: return the original image unchanged.
+         return np.array(image), None, None, "No license plate detected"
+     result_img, license_plate_image, processed_cropped_image, license_plate_text = drawings(
+         image, boxes, scale_x, scale_y, sigma, block_size, constant, ocr)
+     return result_img, license_plate_image, processed_cropped_image, license_plate_text
+
+
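+ # Gradio callback: loads the ONNX detector and runs detection + OCR on the uploaded image.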
+ def predict(image, ocr, sigma, block_size, constant, min_confidence):
+
+     size = 640
+     # The repo id alone is not a local file, so fetch the ONNX weights from the
+     # Hugging Face Hub first (the file name "model.onnx" is an assumption here).
+     model_path = hf_hub_download(repo_id="ml-debi/yolov8_license_plate_detection", filename="model.onnx")
+     ort_session = ort.InferenceSession(model_path)
+
+     result_img, _, processed_cropped_image, license_plate_text = yolo_predictions(
+         image, size, sigma, block_size, constant, min_confidence, ort_session, ocr)
+
+     return result_img, processed_cropped_image, license_plate_text
+
+
+ # Add output license plate text, and add examples and description
+ iface = gr.Interface(
+     fn=predict,
+     inputs=[
+         "image",
+         gr.Dropdown(choices=['pytesseract', 'easyocr'], value="pytesseract", label='OCR Method'),
+         gr.Slider(minimum=0, maximum=1, step=0.01, value=0.33, label='Sigma for Auto Canny'),
+         gr.Number(value=41, label='Block Size for Adaptive Threshold'),
+         gr.Number(value=1, label='Constant for Adaptive Threshold'),
+         gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Minimum Confidence for NMS')
+     ],
+     outputs=[
+         gr.Image(label="Predicted image"),
+         gr.Image(label="Processed license plate image"),
+         gr.Textbox(label="Predicted license plate number")
+     ]
+ )
+ iface.launch()
packages.txt ADDED
@@ -0,0 +1 @@
+ tesseract-ocr
requirements.txt ADDED
@@ -0,0 +1,99 @@
+ aiofiles==23.2.1
+ altair==5.2.0
+ annotated-types==0.6.0
+ anyio==3.7.1
+ astroid==3.0.1
+ attrs==23.1.0
+ certifi==2023.11.17
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ coloredlogs==15.0.1
+ contourpy==1.1.1
+ cycler==0.12.1
+ dill==0.3.7
+ easyocr==1.7.1
+ exceptiongroup==1.2.0
+ fastapi==0.104.1
+ ffmpy==0.3.1
+ filelock==3.13.1
+ flatbuffers==23.5.26
+ fonttools==4.46.0
+ fsspec==2023.12.1
+ gradio==4.8.0
+ gradio_client==0.7.1
+ h11==0.14.0
+ httpcore==1.0.2
+ httpx==0.25.2
+ huggingface-hub==0.19.4
+ humanfriendly==10.0
+ idna==3.6
+ imageio==2.33.0
+ importlib-resources==6.1.1
+ isort==5.12.0
+ Jinja2==3.1.2
+ jsonschema==4.20.0
+ jsonschema-specifications==2023.11.2
+ kiwisolver==1.4.5
+ lazy_loader==0.3
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ matplotlib==3.7.4
+ mccabe==0.7.0
+ mdurl==0.1.2
+ mpmath==1.3.0
+ networkx==3.1
+ ninja==1.11.1.1
+ numpy==1.24.4
+ onnxruntime==1.16.3
+ opencv-python==4.8.1.78
+ opencv-python-headless==4.8.1.78
+ orjson==3.9.10
+ packaging==23.2
+ pandas==2.0.3
+ Pillow==10.1.0
+ pkgutil_resolve_name==1.3.10
+ platformdirs==4.1.0
+ protobuf==4.25.1
+ pyclipper==1.3.0.post5
+ pydantic==2.5.2
+ pydantic_core==2.14.5
+ pydub==0.25.1
+ Pygments==2.17.2
+ pylint==3.0.2
+ pyparsing==3.1.1
+ pyreadline3==3.4.1
+ pytesseract==0.3.10
+ python-bidi==0.4.2
+ python-dateutil==2.8.2
+ python-multipart==0.0.6
+ pytz==2023.3.post1
+ PyWavelets==1.4.1
+ PyYAML==6.0.1
+ referencing==0.32.0
+ requests==2.31.0
+ rich==13.7.0
+ rpds-py==0.13.2
+ scikit-image==0.21.0
+ scipy==1.10.1
+ semantic-version==2.10.0
+ shapely==2.0.2
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.0
+ starlette==0.27.0
+ sympy==1.12
+ tifffile==2023.7.10
+ tomli==2.0.1
+ tomlkit==0.12.0
+ toolz==0.12.0
+ torch==2.1.1
+ torchvision==0.16.1
+ tqdm==4.66.1
+ typer==0.9.0
+ typing_extensions==4.8.0
+ tzdata==2023.3
+ urllib3==2.1.0
+ uvicorn==0.24.0.post1
+ websockets==11.0.3
+ zipp==3.17.0