import cv2
import numpy as np
from ultralytics import YOLO
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

def get_yolo_ocr_xloc(
    img,
    yolo_model,
    ocr_pipeline,
    num_cropped_boxes=5,
    expand_px=1,
    expand_px_for_first_last_cha=16,
    yolo_imgsz=640,
    yolo_iou=0,
    yolo_conf=0.07
):
    """
    Detect character bounding boxes and recognize characters in an image using YOLO and OCR.

    Parameters:
        img (numpy.ndarray): Input image as a BGR array (e.g. loaded with cv2.imread).
        yolo_model (YOLO): Instantiated YOLO model for character detection.
        ocr_pipeline (Pipeline): Instantiated ModelScope OCR pipeline for character recognition.
        num_cropped_boxes (int): Number of adjacent boxes to crop for each OCR segment (default: 5).
        expand_px (int): Number of pixels to expand each side of the box for non-edge characters (default: 1).
        expand_px_for_first_last_cha (int): Number of pixels to expand for the first/last character (default: 16).
        yolo_imgsz (int): Image size for YOLO inference (default: 640).
        yolo_iou (float): IOU threshold for YOLO detection (default: 0).
        yolo_conf (float): Confidence threshold for YOLO detection (default: 0.07).

    Returns:
        boxes (list of numpy.ndarray): List of detected bounding boxes [x1, y1, x2, y2], sorted left-to-right.
        recognized_chars (list of str): List of recognized characters, one per box.
        char_x_centers (list of int): List of x-axis center positions for each character.
    """
    # img is expected to be an already-decoded BGR array (e.g. from cv2.imread upstream).
    height, width = img.shape[:2]

    # Scale up the YOLO inference size for very wide images (roughly one step
    # per 10:1 of aspect ratio), capped at 1600 px.
    yolo_scale = (width / height // 10) + 1
    yolo_size = min(int(yolo_imgsz * yolo_scale), 1600)
    results = yolo_model([img], imgsz=yolo_size, iou=yolo_iou, conf=yolo_conf, verbose=False)

    result = results[0]
    boxes = result.boxes.xyxy.cpu().numpy().astype(int)
    boxes = sorted(boxes, key=lambda box: box[0])
    recognized_chars = []
    char_x_centers = []
    n_boxes = len(boxes)

    for j, box in enumerate(boxes):
        # Select a window of up to num_cropped_boxes adjacent boxes centred on
        # box j, clamped at both ends (e.g. n_boxes=10, num_cropped_boxes=5,
        # j=8 -> idxs = [5, 6, 7, 8, 9]).
        if n_boxes <= num_cropped_boxes:
            idxs = list(range(n_boxes))
        else:
            half = num_cropped_boxes // 2
            start = max(0, min(j - half, n_boxes - num_cropped_boxes))
            end = start + num_cropped_boxes
            idxs = list(range(start, end))
        boxes_to_crop = [boxes[idx] for idx in idxs]
        contains_last_char = (n_boxes - 1) in idxs
        # Expand the crop more generously on the outer edges (left edge of the
        # first character, right edge of the last) so edge strokes are not clipped.
        if j == 0:
            left_expand = expand_px_for_first_last_cha
        else:
            left_expand = expand_px
        if contains_last_char:
            right_expand = expand_px_for_first_last_cha
        else:
            right_expand = expand_px
        # Crop a full-height segment spanning the selected boxes.
        crop_x1 = min(b[0] for b in boxes_to_crop)
        crop_x2 = max(b[2] for b in boxes_to_crop)
        crop_y1 = 0
        crop_y2 = img.shape[0]
        if j == 0:
            crop_x1 = max(crop_x1 - left_expand, 0)
        if contains_last_char:
            crop_x2 = min(crop_x2 + right_expand, img.shape[1])
        segment_img = img[crop_y1:crop_y2, crop_x1:crop_x2].copy()
        # Mask the selected character boxes (each padded by expand_px), fill the
        # rest of the segment with its mean background colour, and feather the
        # boundary with a blurred alpha so strokes from neighbouring characters
        # do not leak into the OCR input.
        mask = np.zeros(segment_img.shape[:2], dtype=np.uint8)
        for b in boxes_to_crop:
            bx1 = max(b[0] - crop_x1 - expand_px, 0)
            bx2 = min(b[2] - crop_x1 + expand_px, crop_x2 - crop_x1)
            by1 = 0
            by2 = segment_img.shape[0]
            mask[by1:by2, bx1:bx2] = 255
        non_text_mask = cv2.bitwise_not(mask)
        if np.count_nonzero(non_text_mask) > 0:
            mean_color = cv2.mean(segment_img, mask=non_text_mask)[:3]
            mean_color = np.array(mean_color, dtype=np.uint8)
        else:
            mean_color = np.array([255, 255, 255], dtype=np.uint8)
        mean_img = np.full(segment_img.shape, mean_color, dtype=np.uint8)
        blurred_mask = cv2.GaussianBlur(mask, (15, 15), 0)
        alpha = blurred_mask.astype(np.float32) / 255.0
        alpha = np.expand_dims(alpha, axis=2)
        segment_img_masked = (segment_img * alpha + mean_img * (1 - alpha)).astype(np.uint8)
        ocr_result = ocr_pipeline(segment_img_masked)
        segment_text = ocr_result['text'][0] if 'text' in ocr_result else ''
        segment_text = segment_text.replace(' ', '')
        # Take the character at this box's offset within the cropped window; if
        # the OCR output length does not match the number of boxes, clamp the
        # index to the last recognised character.
        if len(segment_text) == num_cropped_boxes:
            char = segment_text[j - idxs[0]]
        elif len(segment_text) > 0:
            char = segment_text[min(j - idxs[0], len(segment_text) - 1)]
        else:
            char = ''
        recognized_chars.append(char)
        x1, _, x2, _ = box
        x_center = (x1 + x2) // 2
        char_x_centers.append(x_center)
    
    return boxes, recognized_chars, char_x_centers
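

# Example usage (illustrative sketch, not part of the library): the YOLO weight
# path, the ModelScope model id, and the image path below are placeholders --
# substitute your own character-detection weights and OCR recognition model.
if __name__ == "__main__":
    yolo_model = YOLO("path/to/char_detector.pt")  # hypothetical character-detection weights
    ocr_pipeline = pipeline(
        Tasks.ocr_recognition,
        model="damo/cv_convnextTiny_ocr-recognition-general_damo",  # example ModelScope model id
    )

    image = cv2.imread("path/to/text_line.jpg")  # hypothetical single-line text image
    boxes, chars, x_centers = get_yolo_ocr_xloc(image, yolo_model, ocr_pipeline)

    for (x1, y1, x2, y2), ch, xc in zip(boxes, chars, x_centers):
        print(f"'{ch}' at x={xc}  box=({x1}, {y1}, {x2}, {y2})")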