# Spaces: Runtime error
# File size: 5,639 Bytes
# e4f8fe4
import csv
import os

import cv2
import numpy as np
import tensorflow as tf
from ultralytics import YOLO
class YoloLeNetOCR:
    """Two-stage OCR: a YOLO detector locates symbols, a LeNet CNN labels them.

    Detected boxes are read left to right. Boxes the detector classifies as
    "dot" are emitted directly as the token "dot"; every other box is cropped,
    preprocessed and classified by the CNN.
    """

    def __init__(self,
                 yolo_model_path: str,
                 lenet_model_path: str,
                 image_size=(28, 28),
                 conf_threshold=0.25):
        """Load both models and store the inference parameters.

        Args:
            yolo_model_path: Path to the YOLO weights file.
            lenet_model_path: Path to the saved Keras model.
            image_size: (height, width) of the CNN input.
            conf_threshold: Minimum detection confidence for a box to be kept.
        """
        # YOLO detector
        self.detector = YOLO(yolo_model_path)
        # LeNet CNN
        self.cnn = tf.keras.models.load_model(lenet_model_path)
        # Embedded class names and inverse map (no external pkl needed).
        # NOTE(review): assumes both models were trained with this exact label
        # order -- confirm against the training configuration.
        class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'C', 'dot']
        self.inv_map = dict(enumerate(class_names))
        # params
        self.image_size = image_size
        self.conf_threshold = conf_threshold

    def preprocess(self, crop: np.ndarray) -> np.ndarray:
        """Turn a BGR crop into a normalised CNN input of shape (1, H, W, 1).

        Args:
            crop: Non-empty BGR image region.

        Returns:
            float32 array scaled to [0, 1] with shape (1, H, W, 1), where
            (H, W) is ``self.image_size``.
        """
        height, width = self.image_size
        gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
        # cv2.resize takes dsize as (width, height) -- the reverse of the
        # (H, W) order used for the array shape -- so swap explicitly to keep
        # non-square input sizes correct.
        resized = cv2.resize(gray, (width, height))
        normed = resized.astype(np.float32) / 255.0
        return normed.reshape(1, height, width, 1)

    def ocr_image(self, image_path: str) -> str:
        """Run detection and classification on a single image.

        Args:
            image_path: Path to the image file.

        Returns:
            The recognised labels concatenated left to right. Dots appear as
            the literal token "dot". An empty string is returned when no box
            passes the confidence threshold.

        Raises:
            FileNotFoundError: If OpenCV cannot read ``image_path``.
        """
        # 1) Load as single-channel grayscale
        gray0 = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if gray0 is None:
            raise FileNotFoundError(f"Cannot read {image_path}")

        # 2) Convert to BGR by stacking gray into 3 channels
        img = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR)

        # 3) Detect boxes. The threshold is handed to the detector as well, so
        #    that values below the detector's own default are honoured instead
        #    of being filtered out before we ever see them.
        res = self.detector.predict(
            source=img, conf=self.conf_threshold, verbose=False
        )[0]
        boxes = res.boxes.xyxy.cpu().numpy()
        confs = res.boxes.conf.cpu().numpy()
        classes = res.boxes.cls.cpu().numpy().astype(int)

        keep = confs >= self.conf_threshold
        boxes = boxes[keep]
        classes = classes[keep]
        if len(boxes) == 0:
            return ""

        # 4) Sort boxes left-to-right by their x1 coordinate
        order = np.argsort(boxes[:, 0])
        img_h, img_w = img.shape[:2]

        digits = []
        for (x1, y1, x2, y2), cls_idx in zip(boxes[order], classes[order]):
            # The detector class is only consulted to short-circuit dots; an
            # id missing from the label table simply falls through to the CNN.
            if self.inv_map.get(int(cls_idx)) == "dot":
                digits.append("dot")
                continue

            x1, y1, x2, y2 = (int(v) for v in (x1, y1, x2, y2))
            crop = img[max(0, y1):min(img_h, y2), max(0, x1):min(img_w, x2)]
            if crop.size == 0:
                # A degenerate box (zero width/height after truncation and
                # clipping) cannot be classified; skip it rather than let
                # OpenCV raise and discard the rest of the image.
                continue

            # 5) Preprocess and predict with LeNet
            preds = self.cnn.predict(self.preprocess(crop), verbose=0)
            digits.append(self.inv_map[int(np.argmax(preds, axis=1)[0])])

        return "".join(digits)

    def process_dataset_folder(self, dataset_dir, output_file=None):
        """Process all images in a dataset directory and its subdirectories.

        Files ending in .png, .jpg or .jpeg (case-insensitive) are processed.
        An image that fails is reported on stdout and skipped, so one bad file
        does not abort the batch.

        Args:
            dataset_dir (str): Path to the dataset directory.
            output_file (str): Optional path of a CSV file to write
                ``image_path,result`` rows to.

        Returns:
            dict: Dictionary mapping image paths to OCR results, with "dot"
            tokens replaced by ".".
        """
        results = {}
        processed = 0

        # Iterate through all subdirectories
        for root, _, files in os.walk(dataset_dir):
            for file in files:
                if not file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    continue
                image_path = os.path.join(root, file)
                try:
                    ocr_result = self.ocr_image(image_path)
                except Exception as e:  # best-effort batch: report and move on
                    print(f"Error processing {image_path}: {e}")
                    continue

                results[image_path] = ocr_result.replace("dot", ".")
                processed += 1
                # Print progress
                if processed % 100 == 0:
                    print(f"Processed {processed} images")

        # Save results to file if specified
        if output_file:
            self._save_results(results, output_file)
            print(f"Results saved to {output_file}")

        print(f"Total images processed: {processed}")
        return results

    @staticmethod
    def _save_results(results, output_file):
        """Write ``path,result`` rows to ``output_file`` as CSV."""
        # The csv module quotes fields when needed, so a path containing a
        # comma still yields a well-formed row; "\n" keeps plain line endings.
        with open(output_file, 'w', newline='', encoding='utf-8') as f:
            csv.writer(f, lineterminator="\n").writerows(results.items())
# Example usage: OCR every image under new_data/res, save the results as CSV,
# then echo the first few entries.
if __name__ == "__main__":
    reader = YoloLeNetOCR(
        "Models/res_detect_v4.pt",
        "Models/lenet_res_v4.h5",
        conf_threshold=0.3,
    )

    # Walk the whole new_data/res tree and write one row per image.
    ocr_by_path = reader.process_dataset_folder("new_data/res", "ocr_results.csv")

    # Show at most five of the collected results.
    print("\nSample results:")
    for shown, (img_path, recognised) in enumerate(ocr_by_path.items()):
        if shown == 5:
            break
        print(f"{img_path}: {recognised}")
# # -------------------
# # Example usage
# # -------------------
# if __name__ == "__main__":
# ocr = YoloLeNetOCR(
# yolo_model_path="Models/res_detect_v3.pt",
# lenet_model_path="Models/lenet_res_v4.h5",
# conf_threshold=0.3
# )
# result = ocr.ocr_image("new_data/res")
# result = result.replace("dot", ".")
# print("OCR result:", result)