# marconetplusplus / utils/yolo_ocr_xloc.py
# Uploaded by csxmli (commit 981b0ab, verified).
import cv2
import numpy as np
from ultralytics import YOLO
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
def _window_indices(j, n_boxes, num_cropped_boxes):
    """Return the range of box indices forming the OCR window for box ``j``.

    The window holds ``min(n_boxes, num_cropped_boxes)`` consecutive boxes,
    roughly centred on ``j`` and shifted inwards near either end so that it
    never runs past the first or last box.
    """
    if n_boxes <= num_cropped_boxes:
        return range(n_boxes)
    start = max(0, min(j - num_cropped_boxes // 2, n_boxes - num_cropped_boxes))
    return range(start, start + num_cropped_boxes)


def _masked_window_crop(img, boxes, window, expand_px, edge_expand_px):
    """Crop the full-height strip covering ``window``; fade non-character columns to the background mean."""
    window_boxes = [boxes[i] for i in window]

    # Extra horizontal margin only at the true start / end of the text line,
    # applied symmetrically whenever the window contains the first / last box.
    left_pad = edge_expand_px if window.start == 0 else 0
    right_pad = edge_expand_px if window.stop == len(boxes) else 0
    crop_x1 = max(min(b[0] for b in window_boxes) - left_pad, 0)
    crop_x2 = min(max(b[2] for b in window_boxes) + right_pad, img.shape[1])
    segment_img = img[:, crop_x1:crop_x2]
    seg_width = segment_img.shape[1]

    # 255 on character columns (each box widened by expand_px per side), 0 elsewhere.
    mask = np.zeros(segment_img.shape[:2], dtype=np.uint8)
    for b in window_boxes:
        bx1 = max(b[0] - crop_x1 - expand_px, 0)
        bx2 = min(b[2] - crop_x1 + expand_px, seg_width)
        mask[:, bx1:bx2] = 255

    # Handle grayscale (2-D) as well as multi-channel images.
    n_channels = 1 if segment_img.ndim == 2 else segment_img.shape[2]
    background = cv2.bitwise_not(mask)
    if np.count_nonzero(background) > 0:
        # cv2.mean returns a 4-element scalar; keep only the image's channels.
        fill = np.array(cv2.mean(segment_img, mask=background)[:n_channels], dtype=np.uint8)
    else:
        fill = np.full(n_channels, 255, dtype=np.uint8)
    fill_img = np.full(segment_img.shape, fill, dtype=np.uint8)

    # Blurred mask gives a soft transition between kept columns and the fill
    # colour (presumably to avoid hard vertical edges in the OCR input).
    alpha = cv2.GaussianBlur(mask, (15, 15), 0).astype(np.float32) / 255.0
    if segment_img.ndim == 3:
        alpha = alpha[:, :, np.newaxis]
    return (segment_img * alpha + fill_img * (1 - alpha)).astype(np.uint8)


def _ocr_text(ocr_pipeline, image):
    """Run ``ocr_pipeline`` on ``image``; return its first text result without spaces ('' if none)."""
    ocr_result = ocr_pipeline(image)
    texts = ocr_result['text'] if 'text' in ocr_result else None
    # Guard against an empty result list as well as a missing key.
    return texts[0].replace(' ', '') if texts else ''


def get_yolo_ocr_xloc(
    img,
    yolo_model,
    ocr_pipeline,
    num_cropped_boxes=5,
    expand_px=1,
    expand_px_for_first_last_cha=16,
    yolo_imgsz=640,
    yolo_iou=0,
    yolo_conf=0.07
):
    """
    Detect character boxes with YOLO and recognize one character per box with OCR.

    Boxes are sorted left-to-right. For each box, a full-height strip covering a
    window of up to ``num_cropped_boxes`` neighbouring boxes is cropped. Columns
    outside the (slightly widened) boxes are faded to the mean background colour,
    and the strip is passed to OCR. The character at the box's position inside
    the window is taken as its label. OCR is run once per distinct window.

    Parameters:
        img (np.ndarray): Input image as an H x W or H x W x C array (uint8 assumed).
            The masking step handles both layouts. Whether the YOLO/OCR models
            accept grayscale input is up to those models.
        yolo_model: Callable YOLO model (e.g. ``ultralytics.YOLO``) for character detection.
        ocr_pipeline: Callable OCR pipeline (e.g. ModelScope) returning a mapping with a ``'text'`` list.
        num_cropped_boxes (int): Number of adjacent boxes per OCR window (default: 5). Must be >= 1.
        expand_px (int): Pixels each box is widened on both sides when building the
            text mask (default: 1).
        expand_px_for_first_last_cha (int): Extra crop margin added on the left/right
            whenever the window contains the first/last box (default: 16).
        yolo_imgsz (int): Base YOLO inference size. It is scaled up for wide images
            and capped at 1600 (default: 640).
        yolo_iou (float): IoU threshold passed to YOLO as ``iou`` (default: 0).
        yolo_conf (float): Confidence threshold passed to YOLO as ``conf`` (default: 0.07).

    Returns:
        boxes (list of np.ndarray): Detected boxes as int arrays [x1, y1, x2, y2], sorted by x1.
        recognized_chars (list of str): One recognized character per box ('' if OCR returned nothing).
        char_x_centers (list of int): Integer x-centre ``(x1 + x2) // 2`` of each box.

    Raises:
        ValueError: If ``num_cropped_boxes`` is less than 1.
    """
    if num_cropped_boxes < 1:
        raise ValueError(f"num_cropped_boxes must be >= 1, got {num_cropped_boxes}")

    height, width = img.shape[:2]
    # One extra multiple of yolo_imgsz per full 10:1 of aspect ratio, capped at 1600.
    yolo_scale = (width / height // 10) + 1
    yolo_size = min(int(yolo_imgsz * yolo_scale), 1600)
    result = yolo_model([img], imgsz=yolo_size, iou=yolo_iou, conf=yolo_conf, verbose=False)[0]
    boxes = sorted(result.boxes.xyxy.cpu().numpy().astype(int), key=lambda box: box[0])

    n_boxes = len(boxes)
    window_texts = {}  # (start, stop) -> OCR text; the crop depends only on the window
    recognized_chars = []
    char_x_centers = []
    for j, box in enumerate(boxes):
        window = _window_indices(j, n_boxes, num_cropped_boxes)
        key = (window.start, window.stop)
        if key not in window_texts:
            segment = _masked_window_crop(img, boxes, window, expand_px, expand_px_for_first_last_cha)
            window_texts[key] = _ocr_text(ocr_pipeline, segment)
        segment_text = window_texts[key]
        # Position of box j inside its window. If OCR returned fewer characters
        # than boxes, clamp to the last one.
        offset = j - window.start
        recognized_chars.append(segment_text[min(offset, len(segment_text) - 1)] if segment_text else '')
        char_x_centers.append((box[0] + box[2]) // 2)
    return boxes, recognized_chars, char_x_centers