import io

import cv2
import numpy as np
from PIL import Image

# NOTE: many of the names imported below are never referenced directly in this
# module. They must stay in scope because create_operators() and
# GTCLabelEncode resolve operator classes by name with eval().
from .abinet_label_encode import ABINetLabelEncode
from .ar_label_encode import ARLabelEncode
from .ce_label_encode import CELabelEncode
from .char_label_encode import CharLabelEncode
from .cppd_label_encode import CPPDLabelEncode
from .ctc_label_encode import CTCLabelEncode
from .ep_label_encode import EPLabelEncode
from .igtr_label_encode import IGTRLabelEncode
from .mgp_label_encode import MGPLabelEncode
from .rec_aug import ABINetAug
from .rec_aug import BaseDataAugmentation as BDA
from .rec_aug import PARSeqAug, PARSeqAugPIL, SVTRAug
from .resize import (ABINetResize, CDistNetResize, LongResize, RecTVResize,
                     RobustScannerRecResizeImg, SliceResize, SliceTVResize,
                     SRNRecResizeImg, SVTRResize, VisionLANResize,
                     RecDynamicResize)
from .smtr_label_encode import SMTRLabelEncode
from .srn_label_encode import SRNLabelEncode
from .visionlan_label_encode import VisionLANLabelEncode
from .cam_label_encode import CAMLabelEncode


class KeepKeys(object):
    """Reduce a sample dict to a list holding only the selected values.

    Args:
        keep_keys: iterable of keys to extract; it also fixes the order of
            the returned list.
        **kwargs: accepted and ignored.
    """

    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        """Return ``[data[k] for k in keep_keys]``.

        Raises:
            KeyError: if one of the requested keys is missing from ``data``.
        """
        return [data[key] for key in self.keep_keys]


def transform(data, ops=None):
    """Apply ``ops`` to ``data`` one after another.

    Args:
        data: the sample handed to the first operator.
        ops: list of callables; ``None`` is treated as an empty list.

    Returns:
        The output of the last operator, or ``None`` as soon as any operator
        returns ``None`` (the remaining operators are then skipped).
    """
    if ops is None:
        ops = []
    for op in ops:
        data = op(data)
        if data is None:
            return None
    return data


class Fasttext(object):
    """Attach a fastText lookup of the label to the sample.

    Args:
        path: model path passed to ``fasttext.load_model``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, path='None', **kwargs):
        # pip install fasttext==0.9.1
        # Imported lazily so the dependency is only needed when this operator
        # is actually instantiated.
        import fasttext
        self.fast_model = fasttext.load_model(path)

    def __call__(self, data):
        """Store ``fast_model[data['label']]`` under ``data['fast_label']``."""
        label = data['label']
        data['fast_label'] = self.fast_model[label]
        return data


class DecodeImage(object):
    """Decode an encoded image byte string with OpenCV.

    Args:
        img_mode: ``'RGB'`` reverses the channel axis of the decoded image;
            ``'GRAY'`` applies a ``COLOR_GRAY2BGR`` conversion; any other
            value leaves the decoded array untouched.
        channel_first: if true, transpose the result with axes ``(2, 0, 1)``.
        ignore_orientation: if true, decode with
            ``cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 img_mode='RGB',
                 channel_first=False,
                 ignore_orientation=False,
                 **kwargs):
        self.img_mode = img_mode
        self.channel_first = channel_first
        self.ignore_orientation = ignore_orientation

    def __call__(self, data):
        """Replace ``data['image']`` (bytes) with the decoded array.

        Returns:
            The updated ``data`` dict, or ``None`` when OpenCV fails to decode
            the buffer.

        Raises:
            AssertionError: if ``data['image']`` is not a non-empty ``bytes``
                object, or (in RGB mode) the decoded image does not have
                three channels.
        """
        img = data['image']
        assert isinstance(img, bytes) and len(
            img) > 0, "invalid input 'img' in DecodeImage"
        img = np.frombuffer(img, dtype='uint8')
        if self.ignore_orientation:
            img = cv2.imdecode(
                img, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)
        else:
            img = cv2.imdecode(img, 1)
        if img is None:
            return None
        if self.img_mode == 'GRAY':
            # NOTE(review): the buffer was decoded with a colour flag above,
            # so the array is presumably already 3-channel here; confirm that
            # COLOR_GRAY2BGR is valid for that input.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif self.img_mode == 'RGB':
            # The shape must be wrapped in a 1-tuple: formatting with the bare
            # shape tuple would raise TypeError instead of this AssertionError.
            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
                img.shape, )
            # Reverse the channel axis.
            img = img[:, :, ::-1]

        if self.channel_first:
            img = img.transpose((2, 0, 1))

        data['image'] = img
        return data


class DecodeImagePIL(object):
    """Decode an encoded image byte string into a PIL image.

    Args:
        img_mode: ``'Gray'`` converts the image to mode ``'L'``; ``'BGR'``
            reverses the channel order; any other value keeps the RGB image.
        **kwargs: accepted and ignored.
    """

    def __init__(self, img_mode='RGB', **kwargs):
        self.img_mode = img_mode

    def __call__(self, data):
        """Replace ``data['image']`` (bytes) with a ``PIL.Image``.

        Raises:
            AssertionError: if ``data['image']`` is not a non-empty ``bytes``
                object.
        """
        img = data['image']
        assert isinstance(img, bytes) and len(
            img) > 0, "invalid input 'img' in DecodeImagePIL"
        img = Image.open(io.BytesIO(img)).convert('RGB')
        # NOTE(review): this class spells the mode 'Gray' while DecodeImage
        # uses 'GRAY'; confirm which spelling the configs rely on.
        if self.img_mode == 'Gray':
            img = img.convert('L')
        elif self.img_mode == 'BGR':
            # Convert to a numpy array and reverse the last (channel) axis.
            img = np.array(img)[:, :, ::-1]
            img = Image.fromarray(np.uint8(img))
        data['image'] = img
        return data


def create_operators(op_param_list, global_config=None):
    """Instantiate operators from a list of single-entry config dicts.

    Args:
        op_param_list (list): each item is a dict with exactly one entry
            mapping an operator class name to its keyword arguments (a dict,
            or ``None`` for no arguments).
        global_config (dict, optional): extra keyword arguments merged into
            every operator's arguments; on a key clash the global value wins.

    Returns:
        list: the operator instances, in the order of ``op_param_list``.

    Raises:
        AssertionError: if ``op_param_list`` is not a list, or an item is not
            a dict with exactly one entry.
    """
    assert isinstance(op_param_list, list), 'operator config should be a list'
    ops = []
    for operator in op_param_list:
        assert isinstance(operator,
                          dict) and len(operator) == 1, 'yaml format error'
        op_name = list(operator)[0]
        op_config = operator[op_name]
        # Work on a shallow copy so that merging global_config does not
        # mutate the caller's config dict.
        param = {} if op_config is None else dict(op_config)
        if global_config is not None:
            param.update(global_config)
        # SECURITY: eval() executes whatever string the config supplies as
        # the operator name. Only load configs from trusted sources.
        op = eval(op_name)(**param)
        ops.append(op)
    return ops


class GTCLabelEncode(object):
    """Encode a label with a configurable encoder plus a CTC encoder.

    Args:
        gtc_label_encode (dict): config of the primary encoder. Its
            ``'name'`` entry selects the encoder class, and the whole dict
            (including ``'name'``) is forwarded as keyword arguments.
        max_text_length: forwarded to both encoders.
        character_dict_path: forwarded to both encoders.
        use_space_char: forwarded to both encoders.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 gtc_label_encode,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        # SECURITY: eval() executes the configured class name. Only load
        # configs from trusted sources.
        self.gtc_label_encode = eval(gtc_label_encode['name'])(
            max_text_length=max_text_length,
            character_dict_path=character_dict_path,
            use_space_char=use_space_char,
            **gtc_label_encode)
        self.ctc_label_encode = CTCLabelEncode(max_text_length,
                                               character_dict_path,
                                               use_space_char)

    def __call__(self, data):
        """Run both encoders and merge the CTC results into the sample.

        Returns:
            The sample produced by the primary encoder with ``'ctc_label'``
            and ``'ctc_length'`` added, or ``None`` if either encoder
            returns ``None``.
        """
        # The CTC encoder gets its own dict holding only the label, so it
        # does not write into the sample passed to the primary encoder.
        data_ctc = self.ctc_label_encode({'label': data['label']})
        data = self.gtc_label_encode(data)
        if data_ctc is None or data is None:
            return None
        data['ctc_label'] = data_ctc['label']
        data['ctc_length'] = data_ctc['length']
        return data