Spaces:
Running
on
Zero
Running
on
Zero
import cv2 | |
import numpy as np | |
from ultralytics import YOLO | |
from modelscope.pipelines import pipeline | |
from modelscope.utils.constant import Tasks | |
def _window_indices(j, n_boxes, window_size):
    """Return the range of consecutive box indices used as OCR context for box ``j``.

    The window is centred on ``j`` when possible and shifted inward near either
    end, so it always holds ``min(window_size, n_boxes)`` indices.
    """
    if n_boxes <= window_size:
        return range(n_boxes)
    start = max(0, min(j - window_size // 2, n_boxes - window_size))
    return range(start, start + window_size)


def _fill_background(segment_img, mask):
    """Blend pixels outside ``mask`` toward the segment's mean background colour.

    Where ``mask`` is 255 the original pixel is kept; where it is 0 the pixel is
    replaced by the mean colour of the unmasked area (white if every pixel is
    masked). A 15x15 Gaussian blur of the mask softens the transition.
    """
    n_channels = 1 if segment_img.ndim == 2 else segment_img.shape[2]
    non_text_mask = cv2.bitwise_not(mask)
    if np.count_nonzero(non_text_mask) > 0:
        # cv2.mean always returns a 4-element Scalar; keep only the real channels
        # so grayscale and 4-channel images broadcast correctly below.
        mean_color = np.array(
            cv2.mean(segment_img, mask=non_text_mask)[:n_channels], dtype=np.uint8
        )
    else:
        mean_color = np.full(n_channels, 255, dtype=np.uint8)
    mean_img = np.full(segment_img.shape, mean_color, dtype=np.uint8)

    alpha = cv2.GaussianBlur(mask, (15, 15), 0).astype(np.float32) / 255.0
    if segment_img.ndim == 3:
        alpha = alpha[:, :, np.newaxis]  # broadcast over the channel axis
    return (segment_img * alpha + mean_img * (1 - alpha)).astype(np.uint8)


def _ocr_window(img, window_boxes, pad_left, pad_right, expand_px, ocr_pipeline):
    """OCR the full-height strip spanning ``window_boxes``; return its text without spaces.

    The strip runs from the leftmost box's x1 minus ``pad_left`` to the rightmost
    box's x2 plus ``pad_right``, clamped to the image. Columns not covered by any
    box (each widened by ``expand_px`` per side) are blended toward the
    background colour before OCR. Returns '' for an empty strip or when the
    pipeline recognises no text.
    """
    width = img.shape[1]
    crop_x1 = max(min(b[0] for b in window_boxes) - pad_left, 0)
    crop_x2 = min(max(b[2] for b in window_boxes) + pad_right, width)
    segment_img = img[:, crop_x1:crop_x2]
    if segment_img.size == 0:
        # Degenerate (zero-width) crop: nothing to recognise, and GaussianBlur
        # would fail on an empty mask.
        return ''

    seg_width = crop_x2 - crop_x1
    mask = np.zeros(segment_img.shape[:2], dtype=np.uint8)
    for b in window_boxes:
        bx1 = max(b[0] - crop_x1 - expand_px, 0)
        bx2 = min(b[2] - crop_x1 + expand_px, seg_width)
        mask[:, bx1:bx2] = 255

    ocr_result = ocr_pipeline(_fill_background(segment_img, mask))
    texts = ocr_result['text'] if 'text' in ocr_result else []
    # Guard against an empty result list as well as a missing key.
    segment_text = texts[0] if len(texts) > 0 else ''
    return segment_text.replace(' ', '')


def get_yolo_ocr_xloc(
    img,
    yolo_model,
    ocr_pipeline,
    num_cropped_boxes=5,
    expand_px=1,
    expand_px_for_first_last_cha=16,
    yolo_imgsz=640,
    yolo_iou=0,
    yolo_conf=0.07,
):
    """Detect characters with YOLO and recognise each one with windowed OCR.

    Character boxes from ``yolo_model`` are sorted left to right. For each box,
    a full-height strip covering it and its neighbours (``num_cropped_boxes``
    boxes in total, fewer if fewer were detected) is cropped. Columns outside
    the boxes are blended toward the strip's mean background colour, and the
    strip is passed to ``ocr_pipeline``. The character at the box's position
    within the window is taken from the recognised text. It is clamped to the
    last recognised character if OCR returned fewer characters.

    Windows containing the first or last box are padded outward by
    ``expand_px_for_first_last_cha`` pixels (clamped to the image), so edge
    characters are not flush with the crop border. The crop depends only on the
    window, so each distinct window is OCR'd once and its text reused.

    Parameters:
        img (numpy.ndarray): Input image indexed as ``img[row, col]``, either
            grayscale or with a trailing channel axis of up to 4 channels.
        yolo_model (YOLO): Instantiated YOLO model for character detection.
        ocr_pipeline (Pipeline): Instantiated ModelScope OCR pipeline for
            character recognition.
        num_cropped_boxes (int): Number of adjacent boxes cropped together for
            each OCR call (default: 5). Must be >= 1.
        expand_px (int): Pixels each box is widened on both sides when deciding
            which columns are kept as text (default: 1).
        expand_px_for_first_last_cha (int): Pixels the crop is extended beyond
            the first/last character of the line (default: 16).
        yolo_imgsz (int): Base image size for YOLO inference (default: 640).
        yolo_iou (float): IoU threshold for YOLO detection (default: 0).
        yolo_conf (float): Confidence threshold for YOLO detection
            (default: 0.07).

    Returns:
        boxes (list of numpy.ndarray): Integer ``[x1, y1, x2, y2]`` boxes,
            sorted by x1.
        recognized_chars (list of str): One character per box ('' when OCR
            recognised nothing for that box's window).
        char_x_centers (list of int): ``(x1 + x2) // 2`` for each box.

    Raises:
        ValueError: If ``num_cropped_boxes`` is less than 1.
    """
    if num_cropped_boxes < 1:
        raise ValueError(f"num_cropped_boxes must be >= 1, got {num_cropped_boxes}")

    height, width = img.shape[:2]
    if height == 0 or width == 0:
        return [], [], []

    # Wide text lines get a larger YOLO input: +1x for every full 10:1 of
    # aspect ratio, capped at 1600 px.
    yolo_scale = (width / height // 10) + 1
    yolo_size = min(int(yolo_imgsz * yolo_scale), 1600)
    result = yolo_model([img], imgsz=yolo_size, iou=yolo_iou, conf=yolo_conf, verbose=False)[0]
    boxes = result.boxes.xyxy.cpu().numpy().astype(int)
    boxes = sorted(boxes, key=lambda box: box[0])

    n_boxes = len(boxes)
    recognized_chars = []
    char_x_centers = []
    window_text = {}  # (start, stop) of a window -> OCR text for that window
    for j, box in enumerate(boxes):
        window = _window_indices(j, n_boxes, num_cropped_boxes)
        key = (window.start, window.stop)
        if key not in window_text:
            # Pad outward only at the ends of the text line, symmetrically for
            # the first and the last character.
            pad_left = expand_px_for_first_last_cha if window.start == 0 else 0
            pad_right = expand_px_for_first_last_cha if window.stop == n_boxes else 0
            window_text[key] = _ocr_window(
                img,
                boxes[window.start:window.stop],
                pad_left,
                pad_right,
                expand_px,
                ocr_pipeline,
            )
        text = window_text[key]
        offset = j - window.start
        # Clamp to the last recognised character when OCR returned fewer
        # characters than there are boxes in the window.
        recognized_chars.append(text[min(offset, len(text) - 1)] if text else '')
        x1, _, x2, _ = box
        char_x_centers.append((x1 + x2) // 2)
    return boxes, recognized_chars, char_x_centers