import os
import json
import random

from PIL import Image, ImageDraw

from pdf_extract_kit.registry.registry import TASK_REGISTRY
from pdf_extract_kit.utils.data_preprocess import load_pdf
from pdf_extract_kit.tasks.base_task import BaseTask


@TASK_REGISTRY.register("ocr")
class OCRTask(BaseTask):
    """OCR task: run text detection + recognition on images or PDF pages."""

    def __init__(self, model):
        """Init the task based on the given model.

        Args:
            model: task model, must contain a ``predict`` function.
        """
        super().__init__(model)

    def predict_image(self, image):
        """Predict on one image, return text detection and recognition results.

        Args:
            image: PIL.Image.Image (if ``model.predict`` supports other types,
                remember to add a format-conversion step in ``model.predict``).

        Returns:
            List[dict]: list of text bboxes with their content.

        Return example:
            [
                {
                    "category_type": "text",
                    "poly": [
                        380.6792698635707,
                        159.85058512958923,
                        765.1419999999998,
                        159.85058512958923,
                        765.1419999999998,
                        192.51073013642917,
                        380.6792698635707,
                        192.51073013642917
                    ],
                    "text": "this is an example text",
                    "score": 0.97
                },
                ...
            ]
        """
        return self.model.predict(image)

    def prepare_input_files(self, input_path):
        """Expand ``input_path`` into the list of files to process.

        Args:
            input_path: path to a single file, or to a directory of files.

        Returns:
            List[str]: for a directory, its regular files in sorted order
            (sub-directories are skipped); otherwise ``[input_path]``.
        """
        if not os.path.isdir(input_path):
            return [input_path]
        # os.listdir order is arbitrary: sort so results are reproducible, and
        # drop sub-directories, which cannot be opened as an image or a PDF.
        candidates = (
            os.path.join(input_path, fname)
            for fname in sorted(os.listdir(input_path))
        )
        return [path for path in candidates if os.path.isfile(path)]

    def process(self, input_path, save_dir=None, visualize=False):
        """Run OCR on a file or on every file of a directory.

        Args:
            input_path: path to an image, a PDF, or a directory of those.
            save_dir: if set, JSON results are written below this directory
                (one sub-directory per PDF, one JSON file per page / image).
            visualize: if True and ``save_dir`` is set, also save each
                page / image with the detected boxes drawn on it.

        Returns:
            list: one entry per input file, in processing order. An image
            yields the output of ``predict_image``; a PDF yields a list with
            one such output per page.
        """
        res_list = []
        for fpath in self.prepare_input_files(input_path):
            # splitext copes with extensions of any length (".jpeg", ".tiff")
            # and with extension-less names, unlike slicing off 4 characters.
            basename = os.path.splitext(os.path.basename(fpath))[0]

            if fpath.lower().endswith(".pdf"):
                images = load_pdf(fpath)
                pdf_res = []
                for page, img in enumerate(images):
                    page_res = self.predict_image(img)
                    pdf_res.append(page_res)
                    if save_dir:
                        page_dir = os.path.join(save_dir, basename)
                        os.makedirs(page_dir, exist_ok=True)
                        self.save_json_result(
                            page_res,
                            os.path.join(page_dir, f"page_{page+1}.json"),
                        )
                        if visualize:
                            self.visualize_image(
                                img,
                                page_res,
                                os.path.join(page_dir, f"page_{page+1}.jpg"),
                            )
                res_list.append(pdf_res)
            else:
                # Context manager releases the underlying file handle, which
                # matters when a directory holds many images.
                with Image.open(fpath) as image:
                    img_res = self.predict_image(image)
                    if save_dir:
                        os.makedirs(save_dir, exist_ok=True)
                        self.save_json_result(
                            img_res,
                            os.path.join(save_dir, f"{basename}.json"),
                        )
                        if visualize:
                            self.visualize_image(
                                image,
                                img_res,
                                os.path.join(save_dir, f"{basename}.png"),
                            )
                res_list.append(img_res)
        return res_list

    def visualize_image(self, image, ocr_res, save_path="", cate2color=None):
        """Plot each result's bbox and category on the image.

        The boxes are drawn on ``image`` itself (it is modified in place).

        Args:
            image: PIL.Image.Image
            ocr_res: list of OCR detections/recognitions, in the format
                returned by ``self.predict_image``.
            save_path: path to save the visualized image; nothing is saved
                when empty.
            cate2color: optional mapping from ``category_type`` to an outline
                color; categories not listed are drawn in green.
        """
        # A None default avoids sharing one mutable dict across calls.
        if cate2color is None:
            cate2color = {}

        draw = ImageDraw.Draw(image)
        for res in ocr_res:
            box_color = cate2color.get(res['category_type'], (0, 255, 0))
            poly = res['poly']
            # poly is a flat [x0, y0, x1, y1, ...] list. Use the bounding box
            # of all vertices so a rotated quad never produces inverted
            # rectangle coordinates.
            xs, ys = poly[0::2], poly[1::2]
            x_min, y_min = int(min(xs)), int(min(ys))
            x_max, y_max = int(max(xs)), int(max(ys))
            draw.rectangle([x_min, y_min, x_max, y_max], fill=None, outline=box_color, width=1)
            draw.text((x_min, y_min), res['category_type'], (255, 0, 0))

        if save_path:
            image.save(save_path)

    def save_json_result(self, ocr_res, save_path):
        """Save results to a JSON file.

        Args:
            ocr_res: list of OCR detections/recognitions, in the format
                returned by ``self.predict_image``.
            save_path: path of the JSON file to write (UTF-8, non-ASCII text
                is kept unescaped).
        """
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(ocr_res, f, indent=2, ensure_ascii=False)