# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import imghdr
import cv2
import random
import numpy as np
import paddle


def print_dict(d, logger, delimiter=0):
    """
    Recursively log the contents of a dict, indenting nested levels.

    Keys are visited in sorted order. A dict value (or a non-empty list whose
    first element is a dict) is expanded recursively with 4 extra spaces of
    indentation; every other value is logged on a single ``key : value`` line.

    Args:
        d (dict): The mapping to log.
        logger: Object exposing ``info(msg)`` that receives each line.
        delimiter (int): Number of leading spaces for the current level.
    """
    for k, v in sorted(d.items()):
        if isinstance(v, dict):
            logger.info("{}{} : ".format(delimiter * " ", str(k)))
            print_dict(v, logger, delimiter + 4)
        elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
            logger.info("{}{} : ".format(delimiter * " ", str(k)))
            for value in v:
                print_dict(value, logger, delimiter + 4)
        else:
            logger.info("{}{} : {}".format(delimiter * " ", k, v))


def get_check_global_params(mode):
    """
    Return the list of global parameter names associated with ``mode``.

    Args:
        mode (str): ``"train_eval"`` adds both batch-size keys, ``"test"``
            adds only the test batch-size key; any other value returns just
            the common keys.

    Returns:
        list[str]: A new list of parameter names.
    """
    # 'image_shape' used to be listed twice; one entry is sufficient.
    check_params = [
        'use_gpu',
        'max_text_length',
        'image_shape',
        'character_type',
        'loss_type',
    ]
    if mode == "train_eval":
        check_params = check_params + [
            'train_batch_size_per_card',
            'test_batch_size_per_card',
        ]
    elif mode == "test":
        check_params = check_params + ['test_batch_size_per_card']
    return check_params


def _check_image_file(path):
    """
    Tell whether ``path`` ends with one of the supported image suffixes.

    The comparison is case-insensitive and is a plain suffix match: no
    leading dot is required, so e.g. ``"scanjpg"`` also matches.

    Args:
        path (str): File path or name to test.

    Returns:
        bool: True when the lower-cased path ends with a known suffix.
    """
    img_end = ('jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf')
    # str.endswith accepts a tuple, so one call covers every suffix.
    return path.lower().endswith(img_end)


def get_image_file_list(img_file):
    """
    Collect image file paths from a single file or from a directory.

    A directory is scanned one level deep (``os.listdir``); sub-directories
    are not entered.

    Args:
        img_file (str): Path to an image file or to a directory of images.

    Returns:
        list[str]: Sorted list of matching file paths.

    Raises:
        Exception: If ``img_file`` is None, does not exist, or yields no file
            accepted by ``_check_image_file``.
    """
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    imgs_lists = []
    if os.path.isfile(img_file) and _check_image_file(img_file):
        imgs_lists.append(img_file)
    elif os.path.isdir(img_file):
        candidates = (
            os.path.join(img_file, single_file)
            for single_file in os.listdir(img_file)
        )
        imgs_lists = [
            file_path
            for file_path in candidates
            if os.path.isfile(file_path) and _check_image_file(file_path)
        ]

    if not imgs_lists:
        raise Exception("not found any img file in {}".format(img_file))
    return sorted(imgs_lists)


def check_and_read(img_path):
    """
    Read a GIF or PDF file that needs special handling.

    The decision is made on the last three characters of the file name
    (case-insensitive).

    Args:
        img_path (str): Path of the file to inspect.

    Returns:
        tuple: Always a 3-tuple ``(data, is_gif, is_pdf)``:

        * GIF read successfully: ``(frame, True, False)`` where ``frame`` is
          the first frame returned by ``cv2.VideoCapture`` with its last
          (channel) axis reversed.
        * GIF that cannot be read: ``(None, False, False)``.
        * PDF: ``(pages, False, True)`` where ``pages`` is a list with one
          array per page, converted with ``cv2.COLOR_RGB2BGR``.
        * Anything else: ``(None, False, False)``.
    """
    suffix = os.path.basename(img_path)[-3:].lower()

    if suffix == 'gif':
        gif = cv2.VideoCapture(img_path)
        try:
            ret, frame = gif.read()
        finally:
            # Release the capture handle whether or not the read succeeded.
            gif.release()
        if not ret:
            logger = logging.getLogger('ppocr')
            logger.info(
                "Cannot read {}. This gif image maybe corrupted.".format(
                    img_path))
            # Keep the same arity as every other return path so callers can
            # always unpack three values.
            return None, False, False
        if len(frame.shape) == 2 or frame.shape[-1] == 1:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        # Reverse the channel axis of the decoded frame.
        imgvalue = frame[:, :, ::-1]
        return imgvalue, True, False

    if suffix == 'pdf':
        # Imported lazily so the optional PDF dependencies are only needed
        # when a PDF is actually read.
        import fitz
        from PIL import Image

        imgs = []
        with fitz.open(img_path) as pdf:
            for pg in range(0, pdf.page_count):
                page = pdf[pg]
                # Render at 2x zoom first.
                mat = fitz.Matrix(2, 2)
                pm = page.get_pixmap(matrix=mat, alpha=False)

                # if width or height > 2000 pixels, don't enlarge the image
                if pm.width > 2000 or pm.height > 2000:
                    pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

                img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
                img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
                imgs.append(img)
        return imgs, False, True

    return None, False, False


def load_vqa_bio_label_maps(label_map_path):
    """
    Build BIO label <-> id maps from a file holding one class name per line.

    Lines are stripped; names equal (case-insensitively) to ``OTHER``,
    ``OTHERS`` or ``IGNORE`` are dropped because the ``"O"`` label already
    occupies id 0. Each remaining name contributes ``B-<name>`` and
    ``I-<name>`` in file order. All labels are upper-cased in both maps.

    Args:
        label_map_path (str): Path of the UTF-8 encoded class-name file.

    Returns:
        tuple[dict, dict]: ``(label2id_map, id2label_map)``.
    """
    with open(label_map_path, "r", encoding='utf-8') as fin:
        class_names = [line.strip() for line in fin]

    labels = ["O"]
    for name in class_names:
        # "O" has already been in labels
        if name.upper() in ["OTHER", "OTHERS", "IGNORE"]:
            continue
        labels.append("B-" + name)
        labels.append("I-" + name)

    label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
    id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
    return label2id_map, id2label_map


def set_seed(seed=1024):
    """
    Seed the ``random``, ``numpy.random`` and ``paddle`` generators.

    Args:
        seed (int): Seed value passed to all three libraries.
    """
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)


class AverageMeter:
    """
    Track the latest value and the weighted running average of a series.

    Attributes are ``val`` (last value), ``sum`` (weighted sum), ``count``
    (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Record ``val`` observed with weight ``n`` and refresh the average.

        Args:
            val: The new value.
            n (int): Weight of the value (e.g. number of samples it covers).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when the accumulated weight is still zero
        # (e.g. update(x, n=0) on a fresh meter).
        self.avg = self.sum / self.count if self.count else 0