App / main.py
LinhKL2002's picture
Upload folder using huggingface_hub
4dbe5d1 verified
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import copy
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
import os
from cal_rec_boxes import CalRecBoxes
from ch_ppocr_cls import TextClassifier
from ch_ppocr_det import TextDetector
from ch_ppocr_rec import TextRecognizer
from utils import (
LoadImage,
UpdateParameters,
VisRes,
add_round_letterbox,
get_logger,
increase_min_side,
init_args,
read_yaml,
reduce_max_side,
update_model_path,
)
root_dir = Path(__file__).resolve().parent
DEFAULT_CFG_PATH = root_dir / "config.yaml"
logger = get_logger("RapidOCR")
class RapidOCR:
def __init__(self, config_path: Optional[str] = None, **kwargs):
if config_path is not None and Path(config_path).exists():
config = read_yaml(config_path)
else:
config = read_yaml(DEFAULT_CFG_PATH)
config = update_model_path(config)
if kwargs:
updater = UpdateParameters()
config = updater(config, **kwargs)
global_config = config["Global"]
self.print_verbose = global_config["print_verbose"]
self.text_score = global_config["text_score"]
self.min_height = global_config["min_height"]
self.width_height_ratio = global_config["width_height_ratio"]
self.use_det = global_config["use_det"]
self.text_det = TextDetector(config["Det"])
# self.use_cls = global_config["use_cls"]
# self.text_cls = TextClassifier(config["Cls"])
self.use_rec = global_config["use_rec"]
self.text_rec = TextRecognizer(config["Rec"])
self.load_img = LoadImage()
self.max_side_len = global_config["max_side_len"]
self.min_side_len = global_config["min_side_len"]
self.cal_rec_boxes = CalRecBoxes()
def __call__(
self,
img_content: Union[str, np.ndarray, bytes, Path],
use_det: Optional[bool] = None,
use_cls: Optional[bool] = None,
use_rec: Optional[bool] = None,
**kwargs,
) -> Tuple[Optional[List[List[Union[Any, str]]]], Optional[List[float]]]:
use_det = self.use_det if use_det is None else use_det
use_cls = self.use_cls if use_cls is None else use_cls
use_rec = self.use_rec if use_rec is None else use_rec
return_word_box = False
if kwargs:
box_thresh = kwargs.get("box_thresh", 0.5)
unclip_ratio = kwargs.get("unclip_ratio", 1.6)
text_score = kwargs.get("text_score", 0.5)
return_word_box = kwargs.get("return_word_box", False)
self.text_det.postprocess_op.box_thresh = box_thresh
self.text_det.postprocess_op.unclip_ratio = unclip_ratio
self.text_score = text_score
img = self.load_img(img_content)
raw_h, raw_w = img.shape[:2]
op_record = {}
img, ratio_h, ratio_w = self.preprocess(img)
op_record["preprocess"] = {"ratio_h": ratio_h, "ratio_w": ratio_w}
dt_boxes, cls_res, rec_res = None, None, None
det_elapse, cls_elapse, rec_elapse = 0.0, 0.0, 0.0
if use_det:
img, op_record = self.maybe_add_letterbox(img, op_record)
dt_boxes, det_elapse = self.auto_text_det(img)
if dt_boxes is None:
return None, None
img = self.get_crop_img_list(img, dt_boxes)
# if use_cls:
# img, cls_res, cls_elapse = self.text_cls(img)
if use_rec:
rec_res, rec_elapse = self.text_rec(img, return_word_box)
if dt_boxes is not None and rec_res is not None and return_word_box:
rec_res = self.cal_rec_boxes(img, dt_boxes, rec_res)
for rec_res_i in rec_res:
if rec_res_i[2]:
rec_res_i[2] = (
self._get_origin_points(rec_res_i[2], op_record, raw_h, raw_w)
.astype(np.int32)
.tolist()
)
if dt_boxes is not None:
dt_boxes = self._get_origin_points(dt_boxes, op_record, raw_h, raw_w)
ocr_res = self.get_final_res(
dt_boxes, cls_res, rec_res, det_elapse, cls_elapse, rec_elapse
)
return ocr_res
def preprocess(self, img: np.ndarray) -> Tuple[np.ndarray, float, float]:
h, w = img.shape[:2]
max_value = max(h, w)
ratio_h = ratio_w = 1.0
if max_value > self.max_side_len:
img, ratio_h, ratio_w = reduce_max_side(img, self.max_side_len)
h, w = img.shape[:2]
min_value = min(h, w)
if min_value < self.min_side_len:
img, ratio_h, ratio_w = increase_min_side(img, self.min_side_len)
return img, ratio_h, ratio_w
def maybe_add_letterbox(
self, img: np.ndarray, op_record: Dict[str, Any]
) -> Tuple[np.ndarray, Dict[str, Any]]:
h, w = img.shape[:2]
if self.width_height_ratio == -1:
use_limit_ratio = False
else:
use_limit_ratio = w / h > self.width_height_ratio
if h <= self.min_height or use_limit_ratio:
padding_h = self._get_padding_h(h, w)
block_img = add_round_letterbox(img, (padding_h, padding_h, 0, 0))
op_record["padding_1"] = {"top": padding_h, "left": 0}
return block_img, op_record
op_record["padding_1"] = {"top": 0, "left": 0}
return img, op_record
def _get_padding_h(self, h: int, w: int) -> int:
new_h = max(int(w / self.width_height_ratio), self.min_height) * 2
padding_h = int(abs(new_h - h) / 2)
return padding_h
def auto_text_det(
self, img: np.ndarray
) -> Tuple[Optional[List[np.ndarray]], float]:
dt_boxes, det_elapse = self.text_det(img)
if dt_boxes is None or len(dt_boxes) < 1:
return None, 0.0
dt_boxes = self.sorted_boxes(dt_boxes)
return dt_boxes, det_elapse
def get_crop_img_list(
self, img: np.ndarray, dt_boxes: List[np.ndarray]
) -> List[np.ndarray]:
def get_rotate_crop_image(img: np.ndarray, points: np.ndarray) -> np.ndarray:
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3]),
)
)
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2]),
)
)
pts_std = np.array(
[
[0, 0],
[img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height],
]
).astype(np.float32)
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M,
(img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC,
)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
img_crop_list = []
for box in dt_boxes:
tmp_box = copy.deepcopy(box)
img_crop = get_rotate_crop_image(img, tmp_box)
img_crop_list.append(img_crop)
return img_crop_list
@staticmethod
def sorted_boxes(dt_boxes: np.ndarray) -> List[np.ndarray]:
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
for j in range(i, -1, -1):
if (
abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
and _boxes[j + 1][0][0] < _boxes[j][0][0]
):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
else:
break
return _boxes
def _get_origin_points(
self,
dt_boxes: List[np.ndarray],
op_record: Dict[str, Any],
raw_h: int,
raw_w: int,
) -> np.ndarray:
dt_boxes_array = np.array(dt_boxes).astype(np.float32)
for op in reversed(list(op_record.keys())):
v = op_record[op]
if "padding" in op:
top, left = v.get("top"), v.get("left")
dt_boxes_array[:, :, 0] -= left
dt_boxes_array[:, :, 1] -= top
elif "preprocess" in op:
ratio_h = v.get("ratio_h")
ratio_w = v.get("ratio_w")
dt_boxes_array[:, :, 0] *= ratio_w
dt_boxes_array[:, :, 1] *= ratio_h
dt_boxes_array = np.where(dt_boxes_array < 0, 0, dt_boxes_array)
dt_boxes_array[..., 0] = np.where(
dt_boxes_array[..., 0] > raw_w, raw_w, dt_boxes_array[..., 0]
)
dt_boxes_array[..., 1] = np.where(
dt_boxes_array[..., 1] > raw_h, raw_h, dt_boxes_array[..., 1]
)
return dt_boxes_array
def get_final_res(
self,
dt_boxes: Optional[List[np.ndarray]],
cls_res: Optional[List[List[Union[str, float]]]],
rec_res: Optional[List[Tuple[str, float, List[Union[str, float]]]]],
det_elapse: float,
cls_elapse: float,
rec_elapse: float,
) -> Tuple[Optional[List[List[Union[Any, str]]]], Optional[List[float]]]:
if dt_boxes is None and rec_res is None and cls_res is not None:
return cls_res, [cls_elapse]
if dt_boxes is None and rec_res is None:
return None, None
if dt_boxes is None and rec_res is not None:
return [[res[0], res[1]] for res in rec_res], [rec_elapse]
if dt_boxes is not None and rec_res is None:
return [box.tolist() for box in dt_boxes], [det_elapse]
dt_boxes, rec_res = self.filter_result(dt_boxes, rec_res)
if not dt_boxes or not rec_res or len(dt_boxes) <= 0:
return None, None
ocr_res = [[box.tolist(), *res] for box, res in zip(dt_boxes, rec_res)], [
det_elapse,
cls_elapse,
rec_elapse,
]
return ocr_res
def filter_result(
self,
dt_boxes: Optional[List[np.ndarray]],
rec_res: Optional[List[Tuple[str, float]]],
) -> Tuple[Optional[List[np.ndarray]], Optional[List[Tuple[str, float]]]]:
if dt_boxes is None or rec_res is None:
return None, None
filter_boxes, filter_rec_res = [], []
for box, rec_reuslt in zip(dt_boxes, rec_res):
text, score = rec_reuslt[0], rec_reuslt[1]
if float(score) >= self.text_score:
filter_boxes.append(box)
filter_rec_res.append(rec_reuslt)
return filter_boxes, filter_rec_res
def main():
args = init_args()
ocr_engine = RapidOCR(**vars(args))
use_det = not args.no_det
use_cls = not args.no_cls
use_rec = not args.no_rec
result, elapse_list = ocr_engine(
args.img_path, use_det=use_det, use_cls=use_cls, use_rec=use_rec, **vars(args)
)
logger.info(result)
# Save the recognized text to a text file in the 'results' folder
if use_det and use_rec:
boxes, txts, scores = list(zip(*result))
# Create the 'results' folder if it doesn't exist
results_folder = Path("results")
results_folder.mkdir(parents=True, exist_ok=True)
# Create the file path for saving the text in 'results' folder
img_name = os.path.splitext(os.path.basename(args.img_path))[0] # Get the image name without extension
txt_file_path = results_folder / f"{img_name}.txt" # Save in 'results' folder
# Write the recognized text to the text file
with open(txt_file_path, 'w', encoding='utf-8') as f:
for txt in txts:
f.write(txt + '\n')
logger.info("The recognized text has been saved in %s", txt_file_path)
if args.print_cost:
logger.info(elapse_list)
if args.vis_res:
vis = VisRes()
Path(args.vis_save_path).mkdir(parents=True, exist_ok=True)
save_path = Path(args.vis_save_path) / f"{Path(args.img_path).stem}_vis.png"
if use_det and not use_cls and not use_rec:
boxes, *_ = list(zip(*result))
vis_img = vis(args.img_path, boxes)
cv2.imwrite(str(save_path), vis_img)
logger.info("The vis result has saved in %s", save_path)
elif use_det and use_rec:
font_path = Path(args.vis_font_path)
if not font_path.exists():
raise FileExistsError(f"{font_path} does not exist!")
boxes, txts, scores = list(zip(*result))
vis_img = vis(args.img_path, boxes, txts, scores, font_path=font_path)
cv2.imwrite(str(save_path), vis_img)
logger.info("The vis result has saved in %s", save_path)
if __name__ == "__main__":
main()