diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ca8c663035ea187bae3e9827f25579cb1034dd6
--- /dev/null
+++ b/app.py
@@ -0,0 +1,127 @@
+import os
+import gradio as gr # gradio==4.20.0
+
+os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
+import cv2
+import numpy as np
+import json
+import time
+from PIL import Image
+from tools.infer_e2e import OpenOCR, check_and_download_font, draw_ocr_box_txt
+
+drop_score = 0.01
+text_sys = OpenOCR(drop_score=drop_score)
+# warm up 5 times
+img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
+for i in range(5):
+    res = text_sys(img_numpy=img)
+font_path = './simfang.ttf'
+check_and_download_font(font_path)
+
+
+def main(input_image):
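+    # Gradio passes images as RGB arrays; reverse the channel order to BGR,
+    # which is what the OpenOCR pipeline consumes (it is converted back to RGB for drawing below).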
+ img = input_image[:, :, ::-1]
+ starttime = time.time()
+ results, time_dict, mask = text_sys(img_numpy=img, return_mask=True)
+ elapse = time.time() - starttime
+ save_pred = json.dumps(results[0], ensure_ascii=False)
+ image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+ boxes = [res['points'] for res in results[0]]
+ txts = [res['transcription'] for res in results[0]]
+ scores = [res['score'] for res in results[0]]
+ draw_img = draw_ocr_box_txt(
+ image,
+ boxes,
+ txts,
+ scores,
+ drop_score=drop_score,
+ font_path=font_path,
+ )
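+    # Threshold the detector's probability map at 0.3 and scale it to 0-255 so it displays as a grayscale mask.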
+ mask = mask[0, 0, :, :] > 0.3
+ return save_pred, elapse, draw_img, mask.astype('uint8') * 255
+
+
+def get_all_file_names_including_subdirs(dir_path):
+ all_file_names = []
+
+ for root, dirs, files in os.walk(dir_path):
+ for file_name in files:
+ all_file_names.append(os.path.join(root, file_name))
+
+ file_names_only = [os.path.basename(file) for file in all_file_names]
+ return file_names_only
+
+
+def list_image_paths(directory):
+ image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff')
+
+ image_paths = []
+ for root, dirs, files in os.walk(directory):
+ for file in files:
+ if file.lower().endswith(image_extensions):
+ relative_path = os.path.relpath(os.path.join(root, file),
+ directory)
+ full_path = os.path.join(directory, relative_path)
+ image_paths.append(full_path)
+ image_paths = sorted(image_paths)
+ return image_paths
+
+
+def find_file_in_current_dir_and_subdirs(file_name):
+    for root, dirs, files in os.walk('.'):
+        if file_name in files:
+            relative_path = os.path.join(root, file_name)
+            return relative_path
+    return None  # make the "not found" case explicit
+
+
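+# NOTE: predict1 is an unused placeholder; the Run button below is wired to main() instead.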
+def predict1(input_image, Model_type, OCR_type):
+ if OCR_type == 'E2E':
+ return 11111, 'E2E', input_image
+ elif OCR_type == 'STR':
+ return 11111, 'STR', input_image
+ else:
+ return 11111, 'STD', input_image
+
+
+e2e_img_example = list_image_paths('./OCR_e2e_img')
+
+if __name__ == '__main__':
+ css = '.image-container img { width: 100%; max-height: 320px;}'
+
+ with gr.Blocks(css=css) as demo:
+        gr.HTML("""
+                <h1 style="text-align: center;">OpenOCR</h1>
+                """)
+ with gr.Row():
+ with gr.Column(scale=1):
+ input_image = gr.Image(label='Input image',
+ elem_classes=['image-container'])
+
+ examples = gr.Examples(examples=e2e_img_example,
+ inputs=input_image,
+ label='Examples')
+ downstream = gr.Button('Run')
+
+ with gr.Column(scale=1):
+ img_mask = gr.Image(label='mask',
+ interactive=False,
+ elem_classes=['image-container'])
+ img_output = gr.Image(label=' ',
+ interactive=False,
+ elem_classes=['image-container'])
+
+ output = gr.Textbox(label='Result')
+ confidence = gr.Textbox(label='Latency')
+
+ downstream.click(fn=main,
+ inputs=[
+ input_image,
+ ],
+ outputs=[
+ output,
+ confidence,
+ img_output,
+ img_mask,
+ ])
+
+ demo.launch(share=True)
diff --git a/configs/det/dbnet/repvit_db.yml b/configs/det/dbnet/repvit_db.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c9b1bc19f8cebc5d14865513fa657a084380847c
--- /dev/null
+++ b/configs/det/dbnet/repvit_db.yml
@@ -0,0 +1,173 @@
+Global:
+ device: gpu
+ epoch_num: &epoch_num 500
+ log_smooth_window: 20
+ print_batch_step: 100
+ save_model_dir: ./output/det_repsvtr_db
+ save_epoch_step: 10
+ eval_batch_step:
+ - 0
+ - 1000
+ cal_metric_during_train: false
+ checkpoints:
+ pretrained_model: openocr_det_repvit_ch.pth
+ save_inference_dir: null
+ use_visualdl: false
+ infer_img: ./testA
+ save_res_path: ./checkpoints/det_db/predicts_db.txt
+ distributed: true
+ model_type: det
+
+Architecture:
+ algorithm: DB
+ Backbone:
+ name: RepSVTR_det
+ Neck:
+ name: RSEFPN
+ out_channels: 96
+ shortcut: True
+ Head:
+ name: DBHead
+ k: 50
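+    # k is the amplification factor in DB's approximate (differentiable) binarization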
+
+# Loss:
+# name: DBLoss
+# balance_loss: true
+# main_loss_type: DiceLoss
+# alpha: 5
+# beta: 10
+# ohem_ratio: 3
+
+# Optimizer:
+# name: Adam
+# beta1: 0.9
+# beta2: 0.999
+# lr:
+# name: Cosine
+# learning_rate: 0.001 #(8*8c)
+# warmup_epoch: 2
+# regularizer:
+# name: L2
+# factor: 5.0e-05
+
+PostProcess:
+ name: DBPostProcess
+ thresh: 0.3
+ box_thresh: 0.4
+ max_candidates: 1000
+ unclip_ratio: 1.5
+ score_mode: 'slow'
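+  # thresh binarizes the probability map, box_thresh filters low-confidence boxes,
+  # and unclip_ratio controls how far the shrunk text kernels are expanded back into full boxes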
+
+# Metric:
+# name: DetMetric
+# main_indicator: hmean
+
+# Train:
+# dataset:
+# name: SimpleDataSet
+# data_dir: ./train_data/icdar2015/text_localization/
+# label_file_list:
+# - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+# ratio_list: [1.0]
+# transforms:
+# - DecodeImage:
+# img_mode: BGR
+# channel_first: false
+# - DetLabelEncode: null
+# - CopyPaste: null
+# - IaaAugment:
+# augmenter_args:
+# - type: Fliplr
+# args:
+# p: 0.5
+# - type: Affine
+# args:
+# rotate:
+# - -10
+# - 10
+# - type: Resize
+# args:
+# size:
+# - 0.5
+# - 3
+# - EastRandomCropData:
+# size:
+# - 640
+# - 640
+# max_tries: 50
+# keep_ratio: true
+# - MakeBorderMap:
+# shrink_ratio: 0.4
+# thresh_min: 0.3
+# thresh_max: 0.7
+# total_epoch: *epoch_num
+# - MakeShrinkMap:
+# shrink_ratio: 0.4
+# min_text_size: 8
+# total_epoch: *epoch_num
+# - NormalizeImage:
+# scale: 1./255.
+# mean:
+# - 0.485
+# - 0.456
+# - 0.406
+# std:
+# - 0.229
+# - 0.224
+# - 0.225
+# order: hwc
+# - ToCHWImage: null
+# - KeepKeys:
+# keep_keys:
+# - image
+# - threshold_map
+# - threshold_mask
+# - shrink_map
+# - shrink_mask
+# loader:
+# shuffle: true
+# drop_last: false
+# batch_size_per_card: 8
+# num_workers: 8
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/icdar2015/text_localization/
+ label_file_list:
+ - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - DetLabelEncode: null
+ - DetResizeForTest:
+ # image_shape: [1280, 1280]
+ # keep_ratio: True
+ # padding: True
+ limit_side_len: 960
+ limit_type: max
+ - NormalizeImage:
+ scale: 1./255.
+ mean:
+ - 0.485
+ - 0.456
+ - 0.406
+ std:
+ - 0.229
+ - 0.224
+ - 0.225
+ order: hwc
+ - ToCHWImage: null
+ - KeepKeys:
+ keep_keys:
+ - image
+ - shape
+ - polys
+ - ignore_tags
+ loader:
+ shuffle: false
+ drop_last: false
+ batch_size_per_card: 1
+ num_workers: 2
+profiler_options: null
diff --git a/configs/rec/abinet/resnet45_trans_abinet_lang.yml b/configs/rec/abinet/resnet45_trans_abinet_lang.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ce2b2994b9b580a283404e2fae0ddc1a266b09ed
--- /dev/null
+++ b/configs/rec/abinet/resnet45_trans_abinet_lang.yml
@@ -0,0 +1,94 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet45_trans_abinet_lang/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./openocr_nolang_abinet_lang.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_resnet45_trans_abinet_lang.txt
+ grad_clip_val: 20
+ use_amp: True
+
+Optimizer:
+ name: Adam
+ lr: 0.000267
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: MultiStepLR
+ milestones: [12]
+ gamma: 0.1
+
+Architecture:
+ model_type: rec
+ algorithm: ABINet
+ Transform:
+ Encoder:
+ name: ResNet45
+ in_channels: 3
+ strides: [2, 1, 2, 1, 1]
+ Decoder:
+ name: ABINetDecoder
+ iter_size: 3
+
+Loss:
+ name: ABINetLoss
+
+PostProcess:
+ name: ABINetLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ABINetLabelEncode:
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ABINetLabelEncode:
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml b/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b065c0ee773808b6a2f897ee89b3a0aa5d657f0e
--- /dev/null
+++ b/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml
@@ -0,0 +1,93 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet45_trans_abinet_wo_lang/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_resnet45_trans_abinet_wo_lang.txt
+ grad_clip_val: 20
+ use_amp: True
+
+Optimizer:
+ name: Adam
+ lr: 0.000267
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: MultiStepLR
+ milestones: [12]
+ gamma: 0.1
+
+Architecture:
+ model_type: rec
+ algorithm: ABINet
+ Transform:
+ Encoder:
+ name: ResNet45
+ in_channels: 3
+ strides: [2, 1, 2, 1, 1]
+ Decoder:
+ name: ABINetDecoder
+ iter_size: 0
+
+Loss:
+ name: ABINetLoss
+
+PostProcess:
+ name: ABINetLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ABINetLabelEncode:
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ABINetLabelEncode:
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/abinet/svtrv2_abinet_lang.yml b/configs/rec/abinet/svtrv2_abinet_lang.yml
new file mode 100644
index 0000000000000000000000000000000000000000..424b1a3ad7b344df47d3da7b87024a68554efbce
--- /dev/null
+++ b/configs/rec/abinet/svtrv2_abinet_lang.yml
@@ -0,0 +1,130 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_abinet_lang/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./openocr_svtrv2_nolang_abinet_lang.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_abinet_lang.txt
+ use_amp: True
+ grad_clip_val: 20
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
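+  # note: this LR assumes the 4 GPU x 256/GPU global batch size above; adjust it if the batch size changes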
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: ABINet
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: ABINetDecoder
+ iter_size: 3
+ num_layers: 0
+
+Loss:
+ name: ABINetLoss
+
+PostProcess:
+ name: ABINetLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ABINetLabelEncode:
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ABINetLabelEncode:
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/abinet/svtrv2_abinet_wo_lang.yml b/configs/rec/abinet/svtrv2_abinet_wo_lang.yml
new file mode 100644
index 0000000000000000000000000000000000000000..457571d2609c1c14c026eb5543368c64c09238a0
--- /dev/null
+++ b/configs/rec/abinet/svtrv2_abinet_wo_lang.yml
@@ -0,0 +1,128 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_abinet_wo_lang/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_abinet_wo_lang.txt
+ use_amp: True
+ grad_clip_val: 20
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: ABINet
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: ABINetDecoder
+ iter_size: 0
+ num_layers: 0
+Loss:
+ name: ABINetLoss
+
+PostProcess:
+ name: ABINetLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ABINetLabelEncode:
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ABINetLabelEncode:
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/aster/resnet31_lstm_aster_tps_on.yml b/configs/rec/aster/resnet31_lstm_aster_tps_on.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2df1135a04fd216844031cb00fd1a448c75ec7d2
--- /dev/null
+++ b/configs/rec/aster/resnet31_lstm_aster_tps_on.yml
@@ -0,0 +1,93 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet31_lstm_aster_tps_on
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_aster_tps.txt
+ use_amp: True
+ grad_clip_val: 1.0
+
+Optimizer:
+ name: Adam
+ lr: 0.002 # for 1gpus bs1024/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: aster
+ Transform:
+ name: Aster_TPS
+ tps_inputsize: [32, 64]
+ tps_outputsize: [32, 128]
+ Encoder:
+ name: ResNet_ASTER
+ Decoder:
+ name: ASTERDecoder
+
+Loss:
+ name: ARLoss
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+PostProcess:
+ name: ARLabelDecode
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 1024
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/aster/svtrv2_aster.yml b/configs/rec/aster/svtrv2_aster.yml
new file mode 100644
index 0000000000000000000000000000000000000000..21921cd5e50c56025f333e91297cbbc885230507
--- /dev/null
+++ b/configs/rec/aster/svtrv2_aster.yml
@@ -0,0 +1,127 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_aster
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_aster.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: aster
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: False
+ Decoder:
+ name: ASTERDecoder
+
+Loss:
+ name: ARLoss
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+PostProcess:
+ name: ARLabelDecode
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/aster/svtrv2_aster_tps_on.yml b/configs/rec/aster/svtrv2_aster_tps_on.yml
new file mode 100644
index 0000000000000000000000000000000000000000..34f940143dc32506008626187078d0056a3430c0
--- /dev/null
+++ b/configs/rec/aster/svtrv2_aster_tps_on.yml
@@ -0,0 +1,102 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_aster_tps_on
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_aster_tps_on.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: aster
+ Transform:
+ name: Aster_TPS
+ tps_inputsize: [32, 64]
+ tps_outputsize: [32, 128]
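+    # roughly: the TPS localization network runs on a 32x64 downsampled view and emits a 32x128 rectified image for the encoder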
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: False
+ Decoder:
+ name: ASTERDecoder
+
+Loss:
+ name: ARLoss
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+PostProcess:
+ name: ARLabelDecode
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/autostr/autostr_lstm_aster_tps_on.yml b/configs/rec/autostr/autostr_lstm_aster_tps_on.yml
new file mode 100644
index 0000000000000000000000000000000000000000..11b2b4040bb12f50df16207d2ffe5e137c084f5d
--- /dev/null
+++ b/configs/rec/autostr/autostr_lstm_aster_tps_on.yml
@@ -0,0 +1,95 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/autostr_lstm_aster_tps_on
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_autostr_lstm_aster_tps_on.txt
+ use_amp: True
+ grad_clip_val: 1.0
+
+Optimizer:
+ name: Adam
+ lr: 0.002 # for 4gpus bs256/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: autostr
+ Transform:
+ name: Aster_TPS
+ tps_inputsize: [32, 64]
+ tps_outputsize: [32, 128]
+ Encoder:
+ name: AutoSTREncoder
+ stride_stages: '[(2, 2), (2, 1), (2, 2), (2, 1), (2, 1)]'
+ conv_op_ids: [2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 4, 1, 1, 6, 6]
+ Decoder:
+ name: ASTERDecoder
+
+Loss:
+ name: ARLoss
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+PostProcess:
+ name: ARLabelDecode
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/busnet/svtrv2_busnet.yml b/configs/rec/busnet/svtrv2_busnet.yml
new file mode 100644
index 0000000000000000000000000000000000000000..505acdcc24f0fe3fb7c189823787ea6111014b70
--- /dev/null
+++ b/configs/rec/busnet/svtrv2_busnet.yml
@@ -0,0 +1,135 @@
+Global:
+ device: gpu
+ epoch_num: 10
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_busnet/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./output/rec/u14m_filter/svtrv2_busnet_pretraining/best.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_busnet.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+  warmup_epoch: 1 # pct_start 0.1*10 = 1ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: BUSBet
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: False
+ Decoder:
+ name: BUSDecoder
+ nhead: 6
+ num_layers: 6
+ dim_feedforward: 1536
+ ignore_index: &ignore_index 100
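+    # shared via the YAML anchor with ABINetLabelEncode and ABINetLoss so padded label positions are skipped by the loss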
+ pretraining: False
+ # return_id: 2
+Loss:
+ name: ABINetLoss
+ ignore_index: *ignore_index
+
+PostProcess:
+ name: ABINetLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ABINetLabelEncode:
+ ignore_index: *ignore_index
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ABINetLabelEncode:
+ ignore_index: *ignore_index
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/busnet/svtrv2_busnet_pretraining.yml b/configs/rec/busnet/svtrv2_busnet_pretraining.yml
new file mode 100644
index 0000000000000000000000000000000000000000..53686d97f20eb00c053aba1f44459ed53939fa77
--- /dev/null
+++ b/configs/rec/busnet/svtrv2_busnet_pretraining.yml
@@ -0,0 +1,134 @@
+Global:
+ device: gpu
+ epoch_num: 10
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_busnet_pretraining/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_busnet_pretraining.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+  warmup_epoch: 1 # pct_start 0.1*10 = 1ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: BUSBet
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: False
+ Decoder:
+ name: BUSDecoder
+ nhead: 6
+ num_layers: 6
+ dim_feedforward: 1536
+ ignore_index: &ignore_index 100
+ pretraining: True
+ # return_id: 0
+Loss:
+ name: ABINetLoss
+ ignore_index: *ignore_index
+
+PostProcess:
+ name: ABINetLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ABINetLabelEncode:
+ ignore_index: *ignore_index
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ABINetLabelEncode:
+ ignore_index: *ignore_index
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/busnet/vit_busnet.yml b/configs/rec/busnet/vit_busnet.yml
new file mode 100644
index 0000000000000000000000000000000000000000..46a7c8aebc402a5e94ab209a316df1d3d0dc56a4
--- /dev/null
+++ b/configs/rec/busnet/vit_busnet.yml
@@ -0,0 +1,104 @@
+Global:
+ device: gpu
+ epoch_num: 10
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/vit_busnet/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_vit_busnet.txt
+ grad_clip_val: 20
+ use_amp: True
+
+Optimizer:
+ name: Adam
+ lr: 0.00053 # 4gpus bs256/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: MultiStepLR
+ milestones: [6]
+ gamma: 0.1
+
+Architecture:
+ model_type: rec
+ algorithm: BUSBet
+ Transform:
+ Encoder:
+ name: ViT
+ img_size: [32,128]
+ patch_size: [4, 8]
+ embed_dim: 384
+ depth: 12
+ num_heads: 6
+ mlp_ratio: 4
+ qkv_bias: True
+ Decoder:
+ name: BUSDecoder
+ nhead: 6
+ num_layers: 6
+ dim_feedforward: 1536
+ ignore_index: &ignore_index 100
+ pretraining: False
+Loss:
+ name: ABINetLoss
+ ignore_index: *ignore_index
+
+PostProcess:
+ name: ABINetLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ABINetLabelEncode:
+ ignore_index: *ignore_index
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ABINetLabelEncode:
+ ignore_index: *ignore_index
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/busnet/vit_busnet_pretraining.yml b/configs/rec/busnet/vit_busnet_pretraining.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f5955014482b49834b18288f7ab96da7f20c02d2
--- /dev/null
+++ b/configs/rec/busnet/vit_busnet_pretraining.yml
@@ -0,0 +1,104 @@
+Global:
+ device: gpu
+ epoch_num: 10
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/vit_busnet_pretraining/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_vit_busnet_pretraining.txt
+ grad_clip_val: 20
+ use_amp: True
+
+Optimizer:
+ name: Adam
+ lr: 0.00053 # 4gpus bs256/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: MultiStepLR
+ milestones: [6]
+ gamma: 0.1
+
+Architecture:
+ model_type: rec
+ algorithm: BUSBet
+ Transform:
+ Encoder:
+ name: ViT
+ img_size: [32,128]
+ patch_size: [4, 8]
+ embed_dim: 384
+ depth: 12
+ num_heads: 6
+ mlp_ratio: 4
+ qkv_bias: True
+ Decoder:
+ name: BUSDecoder
+ nhead: 6
+ num_layers: 6
+ dim_feedforward: 1536
+ ignore_index: &ignore_index 100
+ pretraining: True
+Loss:
+ name: ABINetLoss
+ ignore_index: *ignore_index
+
+PostProcess:
+ name: ABINetLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ABINetLabelEncode:
+ ignore_index: *ignore_index
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ABINetLabelEncode:
+ ignore_index: *ignore_index
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/cam/convnextv2_cam_tps_on.yml b/configs/rec/cam/convnextv2_cam_tps_on.yml
new file mode 100644
index 0000000000000000000000000000000000000000..bcaa857765a1cd65249ad6e54b36e7e9f66677f2
--- /dev/null
+++ b/configs/rec/cam/convnextv2_cam_tps_on.yml
@@ -0,0 +1,118 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/convnextv2_cam_tps_on
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_convnextv2_cam_tps_on.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.0008 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+ eps: 1.e-8
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 : 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: CAM
+ Transform:
+ name: Aster_TPS
+ tps_inputsize: [32, 64]
+ tps_outputsize: &img_shape [32, 128]
+ Encoder:
+ name: CAMEncoder
+ encoder_config:
+ name: ConvNeXtV2
+ depths: [2, 2, 8, 2]
+ dims: [80, 160, 320, 640]
+ strides: [[4,4], [2,1], [2,1], [1,1]]
+ drop_path_rate: 0.2
+ feat2d: True
+ nb_classes: 97
+ strides: [[4,4], [2,1], [2,1], [1,1]]
+ deform_stride: 2
+ stage_idx: 2
+ use_depthwise_unet: True
+ use_more_unet: False
+ binary_loss_type: BanlanceMultiClassCrossEntropyLoss
+ mid_size: True
+ d_embedding: 384
+ Decoder:
+ name: CAMDecoder
+ num_encoder_layers: -1
+ beam_size: 0
+ num_decoder_layers: 2
+ nhead: 8
+ max_len: *max_text_length
+
+Loss:
+ name: CAMLoss
+ loss_weight_binary: 1.5
+ label_smoothing: 0.
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+PostProcess:
+ name: ARLabelDecode
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CAMLabelEncode: # Class handling label
+ font_path: ./arial.ttf
+ image_shape: *img_shape
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'binary_mask'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml b/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3e4b18adecf0c093511cc80ccdf8cc6ca3d4449a
--- /dev/null
+++ b/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml
@@ -0,0 +1,118 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/convnextv2_tiny_cam_tps_on
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: False
+  save_res_path: ./output/rec/u14m_filter/predicts_convnextv2_tiny_cam_tps_on.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.0008 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+ eps: 1.e-8
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 : 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: CAM
+ Transform:
+ name: Aster_TPS
+ tps_inputsize: [32, 64]
+ tps_outputsize: &img_shape [32, 128]
+ Encoder:
+ name: CAMEncoder
+ encoder_config:
+ name: ConvNeXtV2
+ depths: [3, 3, 9, 3]
+ dims: [96, 192, 384, 768]
+ strides: [[4,4], [2,1], [2,1], [1,1]]
+ drop_path_rate: 0.2
+ feat2d: True
+ nb_classes: 97
+ strides: [[4,4], [2,1], [2,1], [1,1]]
+ deform_stride: 2
+ stage_idx: 2
+ use_depthwise_unet: True
+ use_more_unet: False
+ binary_loss_type: BanlanceMultiClassCrossEntropyLoss
+ mid_size: False
+ d_embedding: 512
+ Decoder:
+ name: CAMDecoder
+ num_encoder_layers: -1
+ beam_size: 0
+ num_decoder_layers: 2
+ nhead: 8
+ max_len: *max_text_length
+
+Loss:
+ name: CAMLoss
+ loss_weight_binary: 1.5
+ label_smoothing: 0.
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+PostProcess:
+ name: ARLabelDecode
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CAMLabelEncode: # Class handling label
+ font_path: ./arial.ttf
+ image_shape: *img_shape
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'binary_mask'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/cam/svtrv2_cam_tps_on.yml b/configs/rec/cam/svtrv2_cam_tps_on.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0c4504c0d7689135655a6bfbf1207ec468390aff
--- /dev/null
+++ b/configs/rec/cam/svtrv2_cam_tps_on.yml
@@ -0,0 +1,123 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_cam_tps_on
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_cam_tps_on.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 : 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: CAM
+ Transform:
+ name: Aster_TPS
+ tps_inputsize: [32, 64]
+ tps_outputsize: &img_shape [32, 128]
+ Encoder:
+ name: CAMEncoder
+ encoder_config:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ nb_classes: 97
+ strides: [[4, 4], [1, 1], [2, 1], [1, 1]]
+ k_size: [[2, 2], [1, 1], [2, 1], [1, 1]]
+ q_size: [4, 32]
+ deform_stride: 2
+ stage_idx: 2
+ use_depthwise_unet: True
+ use_more_unet: False
+ binary_loss_type: BanlanceMultiClassCrossEntropyLoss
+ mid_size: True
+ d_embedding: 384
+ Decoder:
+ name: CAMDecoder
+ num_encoder_layers: -1
+ beam_size: 0
+ num_decoder_layers: 2
+ nhead: 8
+ max_len: *max_text_length
+
+Loss:
+ name: CAMLoss
+ loss_weight_binary: 1.5
+ label_smoothing: 0.
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+PostProcess:
+ name: ARLabelDecode
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CAMLabelEncode: # Class handling label
+ font_path: ./arial.ttf
+ image_shape: *img_shape
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'binary_mask'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/cdistnet/resnet45_trans_cdistnet.yml b/configs/rec/cdistnet/resnet45_trans_cdistnet.yml
new file mode 100644
index 0000000000000000000000000000000000000000..938f3e97a8916984e894cf1a28c59aa5183e38c3
--- /dev/null
+++ b/configs/rec/cdistnet/resnet45_trans_cdistnet.yml
@@ -0,0 +1,93 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet45_trans_cdistnet
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_resnet45_trans_cdistnet.txt
+ use_amp: True
+ grad_clip_val: 5
+
+Optimizer:
+ name: Adam
+ lr: 0.002 # for 4gpus bs256/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: CDistNet
+ Transform:
+ Encoder:
+ name: ResNet45
+ in_channels: 3
+ strides: [2, 1, 2, 1, 1]
+ Decoder:
+ name: CDistNetDecoder
+ add_conv: True
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/cdistnet/svtrv2_cdistnet.yml b/configs/rec/cdistnet/svtrv2_cdistnet.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a6212ca9491424359e36a7de1b5d16fb19950f9e
--- /dev/null
+++ b/configs/rec/cdistnet/svtrv2_cdistnet.yml
@@ -0,0 +1,139 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_cdistnet/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_cdistnet.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 #4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: CDistNet
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: CDistNetDecoder
+ add_conv: False
+ num_encoder_blocks: 0
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/cppd/svtr_base_cppd.yml b/configs/rec/cppd/svtr_base_cppd.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f49548a7745503156927404eec88284c3e052201
--- /dev/null
+++ b/configs/rec/cppd/svtr_base_cppd.yml
@@ -0,0 +1,123 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtr_base_cppd/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path
+ # ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_base_cppd.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: CPPD
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 128]
+ out_char_num: 25
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: ['Conv','Conv','Conv','Conv','Conv','Conv', 'Conv','Conv', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ last_stage: False
+ prenorm: True
+ Decoder:
+ name: CPPDDecoder
+ vis_seq: 64
+ num_layer: 2
+ pos_len: False
+ rec_layer: 1
+
+
+Loss:
+ name: CPPDLoss
+ ignore_index: 100
+ smoothing: True
+ pos_len: False
+ sideloss_weight: 1.0
+
+PostProcess:
+ name: CPPDLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CPPDLabelEncode: # Class handling label
+ pos_len: False
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_node', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CPPDLabelEncode: # Class handling label
+ pos_len: False
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 128
+ num_workers: 4
diff --git a/configs/rec/cppd/svtr_base_cppd_ch.yml b/configs/rec/cppd/svtr_base_cppd_ch.yml
new file mode 100644
index 0000000000000000000000000000000000000000..4476cfd77bafa0617c3b880618f47b4c5f0f4fbf
--- /dev/null
+++ b/configs/rec/cppd/svtr_base_cppd_ch.yml
@@ -0,0 +1,126 @@
+Global:
+ device: gpu
+ epoch_num: 100
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/ch/svtr_base_cppd/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_batch_step: [0, 2000]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/ppocr_keys_v1.txt
+ # ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/ch/predicts_svtr_base_cppd.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.0005 # for 4gpus bs128/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: CosineAnnealingLR
+ warmup_epoch: 5
+
+Architecture:
+ model_type: rec
+ algorithm: CPPD
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 256]
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 4]
+ num_heads: [4, 8, 12]
+ mixer: ['Conv','Conv','Conv','Conv','Conv','Conv', 'Conv','Conv', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ last_stage: False
+ prenorm: True
+ Decoder:
+ name: CPPDDecoder
+ vis_seq: 128
+ num_layer: 3
+ pos_len: False
+ rec_layer: 1
+ ch: True
+
+
+Loss:
+ name: CPPDLoss
+ ignore_index: 7000
+ smoothing: True
+ pos_len: False
+ sideloss_weight: 1.0
+
+PostProcess:
+ name: CPPDLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../benchmark_bctr/benchmark_bctr_train
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - CPPDLabelEncode: # Class handling label
+ pos_len: False
+ ch: True
+ ignore_index: 7000
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - SVTRResize:
+ image_shape: [3, 32, 256]
+ padding: True
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_node', 'label_index', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../benchmark_bctr/benchmark_bctr_test/scene_test
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - CPPDLabelEncode: # Class handling label
+ pos_len: False
+ ch: True
+ ignore_index: 7000
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - SVTRResize:
+ image_shape: [3, 32, 256]
+ padding: True
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_node', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 4
diff --git a/configs/rec/cppd/svtr_base_cppd_h8.yml b/configs/rec/cppd/svtr_base_cppd_h8.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8a5f71a63dfa6ca5e7a449fc5a6ac651231037c9
--- /dev/null
+++ b/configs/rec/cppd/svtr_base_cppd_h8.yml
@@ -0,0 +1,123 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtr_base_h8_cppd/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_base_cppd.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: CPPD
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 128]
+ out_char_num: 25
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ sub_k: [[1, 1], [2, 1]]
+ mixer: ['Conv','Conv','Conv','Conv','Conv','Conv', 'Conv','Conv', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ last_stage: False
+ prenorm: True
+ Decoder:
+ name: CPPDDecoder
+ vis_seq: 128
+ num_layer: 2
+ pos_len: False
+ rec_layer: 1
+
+Loss:
+ name: CPPDLoss
+ ignore_index: 100
+ smoothing: True
+ pos_len: False
+ sideloss_weight: 1.0
+
+PostProcess:
+ name: CPPDLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CPPDLabelEncode: # Class handling label
+ pos_len: False
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_node', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CPPDLabelEncode: # Class handling label
+ pos_len: False
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 128
+ num_workers: 4
diff --git a/configs/rec/cppd/svtr_base_cppd_syn.yml b/configs/rec/cppd/svtr_base_cppd_syn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..4a3fc9626999720c5085a0a0d072796f0cc143bb
--- /dev/null
+++ b/configs/rec/cppd/svtr_base_cppd_syn.yml
@@ -0,0 +1,124 @@
+Global:
+ device: gpu
+ epoch_num: 60
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/syn/svtr_base_cppd/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path
+ # ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/syn/predicts_svtr_base_cppd.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.0005 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: CosineAnnealingLR
+ warmup_epoch: 6
+
+Architecture:
+ model_type: rec
+ algorithm: CPPD
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 100]
+ out_char_num: 25
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 4]
+ num_heads: [4, 8, 12]
+ mixer: ['Conv','Conv','Conv','Conv','Conv','Conv', 'Conv','Conv', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ last_stage: False
+ prenorm: True
+ Decoder:
+ name: CPPDDecoder
+ vis_seq: 50
+ num_layer: 3
+ pos_len: False
+ rec_layer: 1
+
+
+Loss:
+ name: CPPDLoss
+ ignore_index: 100
+ smoothing: True
+ pos_len: False
+ sideloss_weight: 1.0
+
+PostProcess:
+ name: CPPDLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: STRLMDBDataSet
+ data_dir: ./
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ # - SVTRRAug:
+ - CPPDLabelEncode: # Class handling label
+ pos_len: False
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - SVTRResize:
+ image_shape: [3, 32, 100]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_node', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - CPPDLabelEncode: # Class handling label
+ pos_len: False
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - SVTRResize:
+ image_shape: [3, 32, 100]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_node', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 4
diff --git a/configs/rec/cppd/svtrv2_cppd.yml b/configs/rec/cppd/svtrv2_cppd.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a369cfee6f7a90ae2d07e6ce351ca3468efd00fe
--- /dev/null
+++ b/configs/rec/cppd/svtrv2_cppd.yml
@@ -0,0 +1,150 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_cppd/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_cppd.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: CPPD
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: False
+ Decoder:
+ name: CPPDDecoder
+ ds: True
+ num_layer: 2
+ pos_len: False
+ rec_layer: 1
+
+
+Loss:
+ name: CPPDLoss
+ ignore_index: 100
+ smoothing: True
+ pos_len: False
+ sideloss_weight: 1.0
+
+PostProcess:
+ name: CPPDLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CPPDLabelEncode: # Class handling label
+ pos_len: False
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_node', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CPPDLabelEncode: # Class handling label
+ pos_len: False
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/dan/resnet45_fpn_dan.yml b/configs/rec/dan/resnet45_fpn_dan.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2c1b55a9988c9761d8e56af0072346e4644decd9
--- /dev/null
+++ b/configs/rec/dan/resnet45_fpn_dan.yml
@@ -0,0 +1,98 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet45_fpn_dan/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_resnet45_fpn_dan.txt
+ use_amp: True
+ grad_clip_val: 20
+
+Optimizer:
+ name: Adam
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: DAN
+ Transform:
+ Encoder:
+ name: ResNet45
+ in_channels: 3
+ strides: [2, 1, 2, 1, 1]
+ return_list: True
+ Decoder:
+ name: DANDecoder
+ max_len: 25
+ channels_list: [64, 128, 256, 512]
+ strides_list: [[2, 2], [1, 1], [1, 1]]
+ in_shape: [8, 32]
+ depth: 4
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode:
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode:
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/dan/svtrv2_dan.yml b/configs/rec/dan/svtrv2_dan.yml
new file mode 100644
index 0000000000000000000000000000000000000000..4d877e61e0d95b3b6eca3556f7e5a88295e41980
--- /dev/null
+++ b/configs/rec/dan/svtrv2_dan.yml
@@ -0,0 +1,130 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_dan
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_dan.txt
+ use_amp: True
+ grad_clip_val: 20
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # 4gpus 256bs/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: DAN
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: DANDecoder
+ use_cam: False
+ max_len: 25
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+    data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+                    '../Union14M-L-LMDB-Filtered/filter_train_hard',
+                    '../Union14M-L-LMDB-Filtered/filter_train_medium',
+                    '../Union14M-L-LMDB-Filtered/filter_train_normal',
+                    '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode:
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode:
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/focalsvtr/focalsvtr_ctc.yml b/configs/rec/focalsvtr/focalsvtr_ctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..13616620df5fc6d3121da547a7e2309be0270355
--- /dev/null
+++ b/configs/rec/focalsvtr/focalsvtr_ctc.yml
@@ -0,0 +1,137 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/focalsvtr_ctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path
+ # ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_focalsvtr_ctc.txt
+
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: FocalSVTR
+ img_size: [32, 128]
+ depths: [6, 6, 6]
+ embed_dim: 96
+ sub_k: [[1, 1], [2, 1], [1, 1]]
+ focal_levels: [3, 3, 3]
+ out_channels: 256
+ last_stage: True
+ Decoder:
+ name: CTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: &padding False
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - PARSeqAug:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: True
+ data_dir_list: ['../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 128
+ max_ratio: 12
+ num_workers: 4
diff --git a/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml b/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..57e97fc8d283b9215970ee57d77b9786a8f40453
--- /dev/null
+++ b/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml
@@ -0,0 +1,168 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/svtrv2_lnconv_nrtr_gtc
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img: ../ltb/img
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/predicts_smtr.txt
+ use_amp: True
+ distributed: true
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: BGPD
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: GTCDecoder
+ infer_gtc: True
+ detach: False
+ gtc_decoder:
+ name: NRTRDecoder
+ num_encoder_layers: -1
+ beam_size: 0
+ num_decoder_layers: 2
+ nhead: 12
+ max_len: *max_text_length
+ ctc_decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: GTCLoss
+ gtc_loss:
+ name: ARLoss
+
+PostProcess:
+ name: GTCLabelDecode
+ gtc_label_decode:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecGTCMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ # max_ratio: &max_ratio 4
+ # min_ratio: 1
+ # base_shape: &base_shape [[64, 64], [96, 48], [112, 40], [128, 32]]
+ # base_h: &base_h 32
+ # padding: &padding False
+ padding: false
+ # padding_rand: true
+ # padding_doub: true
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - PARSeqAug:
+ - GTCLabelEncode: # Class handling label
+ gtc_label_encode:
+ name: ARLabelEncode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'ctc_label', 'ctc_length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - GTCLabelEncode: # Class handling label
+ gtc_label_encode:
+ name: ARLabelEncode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'ctc_label', 'ctc_length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml b/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml
new file mode 100644
index 0000000000000000000000000000000000000000..1b213620839d1c041d3ebc2d9677ecadd2bc8e90
--- /dev/null
+++ b/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml
@@ -0,0 +1,151 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/svtrv2_lnconv_smtr_gtc_long_infer
+ save_epoch_step: 1
+  # evaluation is run every 1000 iterations
+ eval_batch_step: [0, 1000]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img: ../ltb/img
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/predicts_smtr.txt
+ use_amp: True
+ distributed: true
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: BGPD
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: GTCDecoder
+ infer_gtc: False
+ detach: False
+ gtc_decoder:
+ name: SMTRDecoder
+ num_layer: 1
+ ds: True
+ max_len: *max_text_length
+ next_mode: &next True
+ sub_str_len: &subsl 5
+ ctc_decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: CTCLoss
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml b/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml
new file mode 100644
index 0000000000000000000000000000000000000000..66f557ba451ee7e05dbba33f37c521a5efaa43ee
--- /dev/null
+++ b/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml
@@ -0,0 +1,150 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/svtrv2_lnconv_smtr_gtc_nodetach_smtr_long_infer
+ save_epoch_step: 1
+  # evaluation is run every 1000 iterations
+ eval_batch_step: [0, 1000]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/predicts_smtr.txt
+ use_amp: True
+ distributed: true
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+Architecture:
+ model_type: rec
+ algorithm: BGPD
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: GTCDecoder
+ infer_gtc: True
+ detach: False
+ gtc_decoder:
+ name: SMTRDecoder
+ num_layer: 1
+ ds: True
+ max_len: *max_text_length
+ next_mode: &next True
+ sub_str_len: &subsl 5
+ infer_aug: True
+ ctc_decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: GTCLoss
+ ctc_weight: 0.1
+ gtc_loss:
+ name: SMTRLoss
+
+PostProcess:
+ name: GTCLabelDecode
+ gtc_label_decode:
+ name: SMTRLabelDecode
+ next_mode: *next
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ only_gtc: True
+
+Metric:
+ name: RecGTCMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - SMTRLabelEncode: # Class handling label
+ sub_str_len: *subsl
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_subs', 'label_next', 'length_subs',
+ 'label_subs_pre', 'label_next_pre', 'length_subs_pre', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ../ltb/
+ label_file_list: ['../ltb/ultra_long_70_list.txt']
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - GTCLabelEncode: # Class handling label
+ gtc_label_encode:
+ name: ARLabelEncode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: 200
+ - SliceResize:
+ image_shape: [3, 32, 128]
+ padding: False
+ max_ratio: 12
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'ctc_label', 'ctc_length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1
+ num_workers: 2
diff --git a/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml b/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d00af789423abbc52cc192542d57d44a71b59129
--- /dev/null
+++ b/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml
@@ -0,0 +1,152 @@
+Global:
+ device: gpu
+ epoch_num: 60
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/svtrv2_lnconv_smtr_gtc_stream
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/predicts_smtr.txt
+ use_amp: True
+ distributed: true
+ grad_clip_val: 20
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+  warmup_epoch: 5 # pct_start 5/60 ≈ 0.083
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: BGPD
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: GTCDecoder
+ infer_gtc: True
+ detach: False
+ gtc_decoder:
+ name: SMTRDecoder
+ num_layer: 1
+ ds: True
+ max_len: *max_text_length
+ next_mode: &next True
+ sub_str_len: &subsl 5
+ infer_aug: False
+ ctc_decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: GTCLoss
+ ctc_weight: 0.25
+ gtc_loss:
+ name: SMTRLoss
+
+PostProcess:
+ name: GTCLabelDecode
+ gtc_label_decode:
+ name: SMTRLabelDecode
+ next_mode: *next
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ only_gtc: True
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+ stream: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - SMTRLabelEncode: # Class handling label
+ sub_str_len: *subsl
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_subs', 'label_next', 'length_subs',
+ 'label_subs_pre', 'label_next_pre', 'length_subs_pre', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ../ltb/
+ label_file_list: ['../ltb/ultra_long_70_list.txt']
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - GTCLabelEncode: # Class handling label
+ gtc_label_encode:
+ name: ARLabelEncode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - SliceTVResize:
+ image_shape: [32, 128]
+ padding: False
+ max_ratio: 4
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'ctc_label', 'ctc_length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1
+ num_workers: 2
diff --git a/configs/rec/igtr/readme.md b/configs/rec/igtr/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..61c7d337be7d8d573c0aa2b1ddf86b46c0c7d7db
--- /dev/null
+++ b/configs/rec/igtr/readme.md
@@ -0,0 +1,189 @@
+# IGTR
+
+- [IGTR](#igtr)
+ - [1. Introduction](#1-introduction)
+ - [2. Environment](#2-environment)
+ - [Dataset Preparation](#dataset-preparation)
+ - [3. Model Training / Evaluation](#3-model-training--evaluation)
+ - [Citation](#citation)
+
+
+
+## 1. Introduction
+
+Paper:
+
+> [Instruction-Guided Scene Text Recognition](https://arxiv.org/abs/2401.17851)
+> Yongkun Du, Zhineng Chen, Yuchen Su, Caiyan Jia, Yu-Gang Jiang
+
+
+Multi-modal models have recently shown appealing performance on visual recognition tasks, as free-form text-guided training evokes the ability to understand fine-grained visual content. However, current models are either inefficient or cannot be trivially upgraded to scene text recognition (STR) due to the composition difference between natural and text images. We propose a novel instruction-guided scene text recognition (IGTR) paradigm that formulates STR as an instruction learning problem and understands text images by predicting character attributes, e.g., character frequency, position, etc. IGTR first devises $\left \langle condition, question, answer \right \rangle$ instruction triplets, providing rich and diverse descriptions of character attributes. To effectively learn these attributes through question answering, IGTR develops a lightweight instruction encoder, a cross-modal feature fusion module, and a multi-task answer head, which together guide nuanced text image understanding. Furthermore, IGTR realizes different recognition pipelines simply by using different instructions, enabling a character-understanding-based text reasoning paradigm that differs considerably from current methods. Experiments on English and Chinese benchmarks show that IGTR outperforms existing models by significant margins, while maintaining a small model size and fast inference. Moreover, by adjusting the sampling of instructions, IGTR offers an elegant way to tackle the recognition of both rarely appearing and morphologically similar characters, which were previously challenging.
+
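+To make the instruction-triplet idea concrete, the snippet below is a purely illustrative example (not taken from the paper) of the kind of character-attribute question answering IGTR trains on; the real condition/question/answer templates are generated by the label encoder (`IGTRLabelEncode`) in the configs below.
+
+```python
+# Hypothetical illustration only; the actual templates come from IGTR's label encoder.
+triplet = {
+    'condition': 'the image contains a single word of 4 characters',
+    'question': 'which character appears at position 3?',
+    'answer': 'x',  # e.g. for the word "text", counting positions from 1
+}
+```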
+
+The accuracy (%) and model files of IGTR on the public scene text recognition datasets are as follows:
+
+- Trained on the synthetic datasets (MJ+ST) and tested on the common benchmarks; both the training and test datasets are from [PARSeq](https://github.com/baudm/parseq).
+
+| Model | IC13<br/>857 | SVT | IIIT5k<br/>3000 | IC15<br/>1811 | SVTP | CUTE80 | Avg | Config&Model&Log |
+| :-----: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :---------------------------------------------------------------------------------------------: |
+| IGTR-PD | 97.6 | 95.2 | 97.6 | 88.4 | 91.6 | 95.5 | 94.30 | [link](https://drive.google.com/drive/folders/1Pv0CW2hiWC_dIyaB74W1fsXqiX3z5yXA?usp=drive_link) |
+| IGTR-AR | 98.6 | 95.7 | 98.2 | 88.4 | 92.4 | 95.5 | 94.78 | as above |
+
+- Tested on the Union14M-L benchmark, from [Union14M](https://github.com/Mountchicken/Union14M/).
+
+| Model | Curve | Multi-<br/>Oriented | Artistic | Contextless | Salient | Multi-<br/>word | General | Avg | Config&Model&Log |
+| :-----: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: |
+| IGTR-PD | 76.9 | 30.6 | 59.1 | 63.3 | 77.8 | 62.5 | 66.7 | 62.40 | Same as the above table |
+| IGTR-AR | 78.4 | 31.9 | 61.3 | 66.5 | 80.2 | 69.3 | 67.9 | 65.07 | as above |
+
+- Trained on the Union14M-L training set.
+
+| Model | IC13<br/>857 | SVT | IIIT5k<br/>3000 | IC15<br/>1811 | SVTP | CUTE80 | Avg | Config&Model&Log |
+| :----------: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :---------------------------------------------------------------------------------------------: |
+| IGTR-PD | 97.7 | 97.7 | 98.3 | 89.8 | 93.7 | 97.9 | 95.86 | [link](https://drive.google.com/drive/folders/1ZGlzDqEzjrBg8qG2wBkbOm3bLRzFbTzo?usp=drive_link) |
+| IGTR-AR | 98.1 | 98.4 | 98.7 | 90.5 | 94.9 | 98.3 | 96.48 | as above |
+| IGTR-PD-60ep | 97.9 | 98.3 | 99.2 | 90.8 | 93.7 | 97.6 | 96.24 | [link](https://drive.google.com/drive/folders/1ik4hxZDRsjU1RbCA19nwE45Kg1bCnMoa?usp=drive_link) |
+| IGTR-AR-60ep | 98.4 | 98.1 | 99.3 | 91.5 | 94.3 | 97.6 | 96.54 | as above |
+| IGTR-PD-PT | 98.6 | 98.0 | 99.1 | 91.7 | 96.8 | 99.0 | 97.20 | [link](https://drive.google.com/drive/folders/1QM0EWV66IfYI1G0Xm066V2zJA62hH6-1?usp=drive_link) |
+| IGTR-AR-PT | 98.8 | 98.3 | 99.2 | 92.0 | 96.8 | 99.0 | 97.34 | as above |
+
+| Model | Curve | Multi-<br/>Oriented | Artistic | Contextless | Salient | Multi-<br/>word | General | Avg | Config&Model&Log |
+| :----------: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: |
+| IGTR-PD | 88.1 | 89.9 | 74.2 | 80.3 | 82.8 | 79.2 | 83.0 | 82.51 | Same as the above table |
+| IGTR-AR | 90.4 | 91.2 | 77.0 | 82.4 | 84.7 | 84.0 | 84.4 | 84.86 | as above |
+| IGTR-PD-60ep | 90.0 | 92.1 | 77.5 | 82.8 | 86.0 | 83.0 | 84.8 | 85.18 | Same as the above table |
+| IGTR-AR-60ep | 91.0 | 93.0 | 78.7 | 84.6 | 87.3 | 84.8 | 85.6 | 86.43 | as above |
+| IGTR-PD-PT | 92.4 | 92.1 | 80.7 | 83.6 | 87.7 | 86.9 | 85.0 | 86.92 | Same as the above table |
+| IGTR-AR-PT | 93.0 | 92.9 | 81.3 | 83.4 | 88.6 | 88.7 | 85.6 | 87.65 | as above |
+
+- Trained and tested on the Chinese benchmarks, from [Chinese Benchmark](https://github.com/FudanVI/benchmarking-chinese-text-recognition).
+
+| Model | Scene | Web | Document | Handwriting | Avg | Config&Model&Log |
+| :---------: | :---: | :--: | :------: | :---------: | :---: | :---------------------------------------------------------------------------------------------: |
+| IGTR-PD | 73.1 | 74.8 | 98.6 | 52.5 | 74.75 | |
+| IGTR-AR | 75.1 | 76.4 | 98.7 | 55.3 | 76.37 | |
+| IGTR-PD-TS | 73.5 | 75.9 | 98.7 | 54.5 | 75.65 | [link](https://drive.google.com/drive/folders/1H3VRdGHjhawd6fkSC-qlBzVzvYYTpHRg?usp=drive_link) |
+| IGTR-AR-TS | 75.6 | 77.0 | 98.8 | 57.3 | 77.17 | as above |
+| IGTR-PD-Aug | 79.5 | 80.0 | 99.4 | 58.9 | 79.45 | [link](https://drive.google.com/drive/folders/1XFQkCILwcFwA7iYyQY9crnrouaI5sqcZ?usp=drive_link) |
+| IGTR-AR-Aug | 82.0 | 81.7 | 99.5 | 63.8 | 81.74 | as above |
+
+Download all Configs, Models, and Logs from [Google Drive](https://drive.google.com/drive/folders/1mSRDg9Mj5R6PspAdFGXZHDHTCQmjkd8d?usp=drive_link).
+
+
+
+## 2. Environment
+
+- [PyTorch](http://pytorch.org/) version >= 1.13.0
+- Python version >= 3.7
+
+```shell
+git clone -b develop https://github.com/Topdu/OpenOCR.git
+cd OpenOCR
+# A100, Ubuntu 20.04, CUDA 11.8
+conda create -n openocr python==3.8
+conda activate openocr
+conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia
+pip install -r requirements.txt
+```
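+
+As a quick, optional sanity check (a minimal sketch, not part of the official setup), you can confirm that PyTorch sees the GPU before launching training:
+
+```python
+import torch
+
+# Versions should match the conda install command above.
+print('torch:', torch.__version__)
+print('cuda available:', torch.cuda.is_available())
+if torch.cuda.is_available():
+    print('device:', torch.cuda.get_device_name(0))
+```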
+
+#### Dataset Preparation
+
+[English dataset download](https://github.com/baudm/parseq)
+
+[Union14M-L download](https://github.com/Mountchicken/Union14M)
+
+[Chinese dataset download](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download)
+
+The expected filesystem structure is as follows:
+
+```
+benchmark_bctr
+├── benchmark_bctr_test
+│ ├── document_test
+│ ├── handwriting_test
+│ ├── scene_test
+│ └── web_test
+└── benchmark_bctr_train
+ ├── document_train
+ ├── handwriting_train
+ ├── scene_train
+ └── web_train
+evaluation
+├── CUTE80
+├── IC13_857
+├── IC15_1811
+├── IIIT5k
+├── SVT
+└── SVTP
+OpenOCR
+synth
+├── MJ
+│ ├── test
+│ ├── train
+│ └── val
+└── ST
+test # from PARSeq
+├── ArT
+├── COCOv1.4
+├── CUTE80
+├── IC13_1015
+├── IC13_1095
+├── IC13_857
+├── IC15_1811
+├── IC15_2077
+├── IIIT5k
+├── SVT
+├── SVTP
+└── Uber
+u14m # lmdb format
+├── artistic
+├── contextless
+├── curve
+├── general
+├── multi_oriented
+├── multi_words
+└── salient
+Union14M-LMDB-L # lmdb format
+├── train_challenging
+├── train_easy
+├── train_hard
+├── train_medium
+└── train_normal
+```
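+
+The datasets are expected to sit next to the `OpenOCR` directory, matching the relative paths used in the configs (e.g. `../evaluation`, `../Union14M-L-LMDB-Filtered`). The sketch below is an optional helper (an assumption-laden example, not part of the repo) that loads a config with PyYAML and reports which training data directories are missing; set `cfg_path` to the config you actually use:
+
+```python
+import os
+import yaml  # PyYAML; assumed available since the repo's configs are YAML
+
+cfg_path = 'PATH/svtr_base_igtr_XXX.yml'  # placeholder, as in the commands below
+with open(cfg_path) as f:
+    cfg = yaml.safe_load(f)
+
+ds = cfg['Train']['dataset']
+# Ratio-sampling configs use data_dir_list; LMDB configs use a single data_dir.
+data_dirs = ds.get('data_dir_list') or [ds['data_dir']]
+
+for d in data_dirs:
+    print('ok     ' if os.path.isdir(d) else 'MISSING', d)
+```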
+
+
+
+## 3. Model Training / Evaluation
+
+Training:
+
+```shell
+# The configuration file is available from the link provided in the table above.
+# Multi-GPU training
+CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 tools/train_rec.py --c PATH/svtr_base_igtr_XXX.yml
+```
+
+Evaluation:
+
+```shell
+# The configuration file is available from the link provided in the table above.
+# en
+python tools/eval_rec_all_ratio.py --c PATH/svtr_base_igtr_syn.yml
+# ch
+python tools/eval_rec_all_ch.py --c PATH/svtr_base_igtr_ch_aug.yml
+```
+
+## Citation
+
+```bibtex
+@article{Du2024IGTR,
+ title = {Instruction-Guided Scene Text Recognition},
+ author = {Du, Yongkun and Chen, Zhineng and Su, Yuchen and Jia, Caiyan and Jiang, Yu-Gang},
+ journal = {CoRR},
+ eprinttype = {arXiv},
+ primaryClass={cs.CV},
+ volume = {abs/2401.17851},
+ year = {2024},
+ url = {https://arxiv.org/abs/2401.17851}
+}
+```
diff --git a/configs/rec/igtr/svtr_base_ds_igtr.yml b/configs/rec/igtr/svtr_base_ds_igtr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..df29c39d3a28dbb5684fc2d6c94af65470d547ab
--- /dev/null
+++ b/configs/rec/igtr/svtr_base_ds_igtr.yml
@@ -0,0 +1,157 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtr_base_igtr
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path
+ # ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_base_igtr.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.0005 # 2gpus 384bs/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: IGTR
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRNet2DPos
+ img_size: [32, -1]
+ out_char_num: 25
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: ['ConvB','ConvB','ConvB','ConvB','ConvB','ConvB', 'ConvB','ConvB', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ last_stage: False
+ prenorm: True
+ use_first_sub: False
+ Decoder:
+ name: IGTRDecoder
+ dim: 384
+ num_layer: 1
+ ar: False
+ refine_iter: 0
+ # next_pred: True
+ next_pred: False
+ pos2d: True
+ ds: True
+ # pos_len: False
+ # rec_layer: 1
+
+
+Loss:
+ name: IGTRLoss
+
+PostProcess:
+ name: IGTRLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: &padding False
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - PARSeqAug:
+ - IGTRLabelEncode: # Class handling label
+ k: 8
+ prompt_error: False
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'prompt_pos_idx_list',
+ 'prompt_char_idx_list', 'ques_pos_idx_list', 'ques1_answer_list',
+ 'ques2_char_idx_list', 'ques2_answer_list', 'ques3_answer', 'ques4_char_num_list',
+ 'ques_len_list', 'ques2_len_list', 'prompt_len_list', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 384
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: *padding
+ data_dir_list: ['../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP']
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml b/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml
new file mode 100644
index 0000000000000000000000000000000000000000..745c585128ddf28b71b293d575ebc33e8fdfb592
--- /dev/null
+++ b/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml
@@ -0,0 +1,133 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/focalsvtr_lister_wo_fem_maxratio12/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_focalsvtr_lister_wo_fem_maxratio12.txt
+ use_amp: True
+ grad_clip_val: 20
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
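+  # warmup_epoch maps to OneCycleLR's pct_start as warmup_epoch / epoch_num
+  # (1.5 / 20 = 0.075 here), as the inline comments in these configs indicate.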
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: LISTER
+ Transform:
+ Encoder:
+ name: FocalSVTR
+ img_size: [32, 128]
+ depths: [6, 6, 9]
+ embed_dim: 96
+ sub_k: [[1, 1], [2, 1], [1, 1]]
+ focal_levels: [3, 3, 3]
+ last_stage: False
+ feat2d: True
+ Decoder:
+ name: LISTERDecoder
+ detach_grad: False
+ attn_scaling: True
+ use_fem: False
+
+Loss:
+ name: LISTERLoss
+
+PostProcess:
+ name: LISTERLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - EPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: ['../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - EPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: 12
+ num_workers: 4
diff --git a/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml b/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c24b272a2afbc7afbbcfd66dc722939735495be9
--- /dev/null
+++ b/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml
@@ -0,0 +1,138 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_lister_wo_fem_maxratio12/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_lister_wo_fem_maxratio12.txt
+ use_amp: True
+ grad_clip_val: 20
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: LISTER
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: LISTERDecoder
+ detach_grad: False
+ attn_scaling: True
+ use_fem: False
+
+Loss:
+ name: LISTERLoss
+
+PostProcess:
+ name: LISTERLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - EPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: ['../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - EPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: 12
+ num_workers: 4
diff --git a/configs/rec/lpv/svtr_base_lpv.yml b/configs/rec/lpv/svtr_base_lpv.yml
new file mode 100644
index 0000000000000000000000000000000000000000..01c5509750407f47e306430ded119d94fca284d0
--- /dev/null
+++ b/configs/rec/lpv/svtr_base_lpv.yml
@@ -0,0 +1,124 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtr_base_lpv/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./output/rec/u14m_filter/svtr_base_lpv_wo_glrm/best.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_lpv.txt
+ use_amp: True
+ grad_clip_val: 20
+
+Optimizer:
+ name: Adam
+ lr: 0.0001 # for 4gpus bs128/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+ betas: [0.9, 0.99]
+
+LRScheduler:
+ name: MultiStepLR
+ milestones: [12]
+ gamma: 0.1
+
+Architecture:
+ model_type: rec
+ algorithm: LPV
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 128]
+ out_char_num: 25
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: ['Conv','Conv','Conv','Conv','Conv','Conv', 'Conv','Conv', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ sub_k: [[1, 1], [1, 1]]
+ feature2d: True
+ last_stage: False
+ prenorm: True
+ Decoder:
+ name: LPVDecoder
+ num_layer: 3
+ max_len: *max_text_length
+ use_mask: True
+ dim_feedforward: 1536
+ nhead: 12
+ dropout: 0.1
+ trans_layer: 3
+
+Loss:
+ name: LPVLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 128
+ num_workers: 4
diff --git a/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml b/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ede4bdf1f1b7a28f4735613d8dc6defd631c05a7
--- /dev/null
+++ b/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml
@@ -0,0 +1,123 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtr_base_lpv_wo_glrm/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_base_lpv_wo_glrm.txt
+ use_amp: True
+ grad_clip_val: 20
+
+Optimizer:
+ name: Adam
+ lr: 0.0001 # for 4gpus bs128/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+ betas: [0.9, 0.99]
+
+LRScheduler:
+ name: MultiStepLR
+ milestones: [12]
+ gamma: 0.1
+
+Architecture:
+ model_type: rec
+ algorithm: LPV
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 128]
+ out_char_num: 25
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: ['Conv','Conv','Conv','Conv','Conv','Conv', 'Conv','Conv', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ sub_k: [[1, 1], [1, 1]]
+ feature2d: True
+ last_stage: False
+ prenorm: True
+ Decoder:
+ name: LPVDecoder
+ num_layer: 3
+ max_len: *max_text_length
+ use_mask: False
+ dim_feedforward: 1536
+ nhead: 12
+ dropout: 0.1
+ trans_layer: 3
+
+Loss:
+ name: LPVLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 128
+ num_workers: 4
diff --git a/configs/rec/lpv/svtrv2_lpv.yml b/configs/rec/lpv/svtrv2_lpv.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8b61431b9fd003a313389c11c2931cb09f730151
--- /dev/null
+++ b/configs/rec/lpv/svtrv2_lpv.yml
@@ -0,0 +1,147 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_lpv/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./output/rec/u14m_filter/svtrv2_lpv_wo_glrm/best.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_lpv.txt
+ use_amp: True
+ grad_clip_val: 20
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325 # for 4gpus bs128/gpu
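+  # Across these configs the peak LR scales linearly with the global batch size
+  # (0.00065 at 4 GPUs x 256/GPU, 0.000325 at 4 GPUs x 128/GPU); scale lr
+  # accordingly if batch_size_per_card or the GPU count changes.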
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+  warmup_epoch: 1 # pct_start 0.05*20 = 1ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: LPV
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: LPVDecoder
+ num_layer: 3
+ max_len: *max_text_length
+ use_mask: True
+ dim_feedforward: 1536
+ nhead: 12
+ dropout: 0.1
+ trans_layer: 3
+
+Loss:
+ name: LPVLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml b/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml
new file mode 100644
index 0000000000000000000000000000000000000000..607c85ae3affb1fb28aff507bc0bdc0722016296
--- /dev/null
+++ b/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml
@@ -0,0 +1,146 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_lpv_wo_glrm/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_lpv_wo_glrm.txt
+ use_amp: True
+ grad_clip_val: 20
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325 # for 4gpus bs128/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+  warmup_epoch: 1 # pct_start 0.05*20 = 1ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: LPV
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: LPVDecoder
+ num_layer: 3
+ max_len: *max_text_length
+ use_mask: False
+ dim_feedforward: 1536
+ nhead: 12
+ dropout: 0.1
+ trans_layer: 3
+
+Loss:
+ name: LPVLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/maerec/vit_nrtr.yml b/configs/rec/maerec/vit_nrtr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..837b5abba69f67ad34fa2c7d2f3894fdb25e2338
--- /dev/null
+++ b/configs/rec/maerec/vit_nrtr.yml
@@ -0,0 +1,116 @@
+Global:
+ device: gpu
+ epoch_num: 10
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/vit_nrtr_ft_mae/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./open_ocr_vit_small_params.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_vit_nrtr_ft_mae.txt
+ use_amp: True
+ project_name: maerec
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+  warmup_epoch: 1.5 # pct_start 0.15*10 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: BGPD
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: ViT
+ img_size: [32, 128]
+ patch_size: [4, 4]
+ embed_dim: 384
+ depth: 12
+ num_heads: 6
+ mlp_ratio: 4
+ qkv_bias: True
+ use_cls_token: True
+ Decoder:
+ name: NRTRDecoder
+ num_encoder_layers: -1
+ beam_size: 0
+ num_decoder_layers: 6
+ nhead: 8
+ max_len: *max_text_length
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 4
diff --git a/configs/rec/matrn/resnet45_trans_matrn.yml b/configs/rec/matrn/resnet45_trans_matrn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b2a5a71e7a599589b8235b9dc5e325bee92ba9c4
--- /dev/null
+++ b/configs/rec/matrn/resnet45_trans_matrn.yml
@@ -0,0 +1,95 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet45_trans_matrn/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./openocr_nolang_abinet_lang.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_resnet45_trans_matrn.txt
+ grad_clip_val: 20
+ use_amp: True
+
+Optimizer:
+ name: Adam
+ lr: 0.000133 # 4gpus 128bs/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: MultiStepLR
+ milestones: [12, 18]
+ gamma: 0.1
+
+Architecture:
+ model_type: rec
+ algorithm: MATRN
+ Transform:
+ Encoder:
+ name: ResNet45
+ in_channels: 3
+ strides: [2, 1, 2, 1, 1]
+ Decoder:
+ name: MATRNDecoder
+ iter_size: 3
+
+Loss:
+ name: ABINetLoss
+ align_weight: 3.0
+
+PostProcess:
+ name: ABINetLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ABINetLabelEncode:
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ABINetLabelEncode:
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/matrn/svtrv2_matrn.yml b/configs/rec/matrn/svtrv2_matrn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d7e7796f004fc7275128e0aa4cfeced3ca988669
--- /dev/null
+++ b/configs/rec/matrn/svtrv2_matrn.yml
@@ -0,0 +1,130 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_matrn/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./openocr_svtrv2_nolang_abinet_lang.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_matrn.txt
+ use_amp: True
+ grad_clip_val: 20
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325 # for 4gpus bs128/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: MATRN
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: MATRNDecoder
+ iter_size: 3
+ num_layers: 0
+
+Loss:
+ name: ABINetLoss
+
+PostProcess:
+ name: ABINetLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ABINetLabelEncode:
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ABINetLabelEncode:
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml b/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml
new file mode 100644
index 0000000000000000000000000000000000000000..dfb0265f609b21629f95afc3e6103c104b890fd7
--- /dev/null
+++ b/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml
@@ -0,0 +1,140 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_mgpstr_only_char/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ use_amp: True
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_mgpstr_only_char.txt
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # 4gpus 256bs/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: MGPSTR
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: false
+ Decoder:
+ name: MGPDecoder
+ only_char: &only_char True
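+      # only_char: True keeps only the character prediction head; the BPE and
+      # word-piece heads of full MGP-STR are disabled, which is why KeepKeys below
+      # returns just 'char_label' (compare vit_mgpstr.yml, which also keeps
+      # 'bpe_label' and 'wp_label').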
+
+Loss:
+ name: MGPLoss
+ only_char: *only_char
+
+PostProcess:
+ name: MPGLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ only_char: *only_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - MGPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ only_char: *only_char
+ - KeepKeys:
+ keep_keys: ['image', 'char_label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - MGPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ only_char: *only_char
+ - KeepKeys:
+ keep_keys: ['image', 'char_label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml b/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fe9d8243ab9d15ed5c44b3dbbeea1833b9624f8e
--- /dev/null
+++ b/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml
@@ -0,0 +1,111 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+  output_dir: ./output/rec/u14m_filter/vit_base_mgpstr_only_char/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ use_amp: True
+  save_res_path: ./output/rec/u14m_filter/predicts_vit_base_mgpstr_only_char.txt
+ grad_clip_val: 5
+ project_name: mgpstr_base
+
+Optimizer:
+ name: Adam
+ lr: 0.000325 # 4gpus 128bs/gpu
+ weight_decay: 0.
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: MGPSTR
+ Transform:
+ Encoder:
+ name: ViT
+ img_size: [32,128]
+ patch_size: [4, 4]
+ embed_dim: 768
+ depth: 12
+ num_heads: 12
+ mlp_ratio: 4
+ qkv_bias: True
+ Decoder:
+ name: MGPDecoder
+ only_char: &only_char True
+
+Loss:
+ name: MGPLoss
+ only_char: *only_char
+
+PostProcess:
+ name: MPGLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ only_char: *only_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - MGPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ only_char: *only_char
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'char_label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - MGPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ only_char: *only_char
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'char_label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml b/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml
new file mode 100644
index 0000000000000000000000000000000000000000..974ccaf350c9fd9c5aa98e6c877937929129816f
--- /dev/null
+++ b/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml
@@ -0,0 +1,110 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+  output_dir: ./output/rec/u14m_filter/vit_large_mgpstr_only_char/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ use_amp: True
+  save_res_path: ./output/rec/u14m_filter/predicts_vit_large_mgpstr_only_char.txt
+ grad_clip_val: 5
+
+Optimizer:
+ name: Adam
+ lr: 0.000325 # 4gpus 128bs/gpu
+ weight_decay: 0.
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: MGPSTR
+ Transform:
+ Encoder:
+ name: ViT
+ img_size: [32,128]
+ patch_size: [4, 4]
+ embed_dim: 1024
+ depth: 24
+ num_heads: 16
+ mlp_ratio: 4
+ qkv_bias: True
+ Decoder:
+ name: MGPDecoder
+ only_char: &only_char True
+
+Loss:
+ name: MGPLoss
+ only_char: *only_char
+
+PostProcess:
+ name: MPGLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ only_char: *only_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - MGPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ only_char: *only_char
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'char_label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - MGPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ only_char: *only_char
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'char_label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/mgpstr/vit_mgpstr.yml b/configs/rec/mgpstr/vit_mgpstr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b24139c85068578f53ad793f4ce0ade0b80244d4
--- /dev/null
+++ b/configs/rec/mgpstr/vit_mgpstr.yml
@@ -0,0 +1,110 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/vit_mgpstr/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [100000, 2000]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ use_amp: True
+ save_res_path: ./output/rec/u14m_filter/predicts_vit_mgpstr.txt
+ grad_clip_val: 5
+
+Optimizer:
+ name: Adam
+ lr: 0.000325 # 4gpus 128bs/gpu
+ weight_decay: 0.
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: MGPSTR
+ Transform:
+ Encoder:
+ name: ViT
+ img_size: [32,128]
+ patch_size: [4, 4]
+ embed_dim: 384
+ depth: 12
+ num_heads: 6
+ mlp_ratio: 4
+ qkv_bias: True
+ Decoder:
+ name: MGPDecoder
+ only_char: &only_char False
+
+Loss:
+ name: MGPLoss
+ only_char: *only_char
+
+PostProcess:
+ name: MPGLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ only_char: *only_char
+
+Metric:
+ name: RecMPGMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - MGPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ only_char: *only_char
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'char_label', 'bpe_label', 'wp_label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - MGPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ only_char: *only_char
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'char_label', 'bpe_label', 'wp_label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/mgpstr/vit_mgpstr_only_char.yml b/configs/rec/mgpstr/vit_mgpstr_only_char.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2bb300499c4ff931e0aa0103308b14e058d6f397
--- /dev/null
+++ b/configs/rec/mgpstr/vit_mgpstr_only_char.yml
@@ -0,0 +1,110 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/vit_mgpstr_only_char/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ use_amp: True
+ save_res_path: ./output/rec/u14m_filter/predicts_vit_mgpstr_only_char.txt
+ grad_clip_val: 5
+
+Optimizer:
+ name: Adam
+ lr: 0.000325 # 4gpus 128bs/gpu
+ weight_decay: 0.
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: MGPSTR
+ Transform:
+ Encoder:
+ name: ViT
+ img_size: [32,128]
+ patch_size: [4, 4]
+ embed_dim: 384
+ depth: 12
+ num_heads: 6
+ mlp_ratio: 4
+ qkv_bias: True
+ Decoder:
+ name: MGPDecoder
+ only_char: &only_char True
+
+Loss:
+ name: MGPLoss
+ only_char: *only_char
+
+PostProcess:
+ name: MPGLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ only_char: *only_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - MGPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ only_char: *only_char
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'char_label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - MGPLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ only_char: *only_char
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'char_label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/moran/resnet31_lstm_moran.yml b/configs/rec/moran/resnet31_lstm_moran.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f1f53c604c5de2c928e600d7604cb02881a0c628
--- /dev/null
+++ b/configs/rec/moran/resnet31_lstm_moran.yml
@@ -0,0 +1,92 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet31_lstm_moran
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_moran.txt
+ use_amp: True
+ grad_clip_val: 1.0
+
+Optimizer:
+ name: Adam
+ lr: 0.002 # for 1gpus bs1024/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: MORAN
+ Transform:
+ name: MORN
+ target_shape: [32, 128]
+ Encoder:
+ name: ResNet_ASTER
+ Decoder:
+ name: ASTERDecoder
+
+Loss:
+ name: ARLoss
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+PostProcess:
+ name: ARLabelDecode
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 1024
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml b/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ab0d7994e71a1c289f77f4779a113c94ba3b33ff
--- /dev/null
+++ b/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml
@@ -0,0 +1,145 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+  output_dir: ./output/rec/u14m_filter/focalsvtr_nrtr_maxratio12
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img: ../ltb/img
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+  save_res_path: ./output/rec/u14m_filter/predicts_focalsvtr_nrtr_maxratio12.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: NRTR
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: FocalSVTR
+ img_size: [32, 128]
+ depths: [6, 6, 6]
+ embed_dim: 96
+ sub_k: [[1, 1], [2, 1], [1, 1]]
+ focal_levels: [3, 3, 3]
+ last_stage: False
+ Decoder:
+ name: NRTRDecoder
+ num_encoder_layers: -1
+ beam_size: 0
+ num_decoder_layers: 2
+ nhead: 12
+ max_len: *max_text_length
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: &padding True
+ padding_rand: True
+ padding_doub: True
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - PARSeqAug:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: False
+ padding_rand: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ max_ratio: *max_ratio
+ batch_size_per_card: 128
+ num_workers: 4
diff --git a/configs/rec/nrtr/nrtr.yml b/configs/rec/nrtr/nrtr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9a4d738a293b1e95e9ccad59b8e38385e5d5adb2
--- /dev/null
+++ b/configs/rec/nrtr/nrtr.yml
@@ -0,0 +1,107 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/nrtr/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_nrtr.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: BGPD
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: NRTREncoder
+ Decoder:
+ name: NRTRDecoder
+ num_encoder_layers: 6
+ beam_size: 0
+ num_decoder_layers: 6
+ nhead: 8
+ max_len: *max_text_length
+
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/nrtr/svtr_base_nrtr.yml b/configs/rec/nrtr/svtr_base_nrtr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d62abbee5b9976648b36ed55f3517ccc187a7e5f
--- /dev/null
+++ b/configs/rec/nrtr/svtr_base_nrtr.yml
@@ -0,0 +1,118 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtr_base_nrtr/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_base_nrtr.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: NRTR
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 128]
+ out_char_num: 25
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: ['Conv','Conv','Conv','Conv','Conv','Conv', 'Conv','Conv', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ last_stage: False
+ prenorm: True
+ Decoder:
+ name: NRTRDecoder
+ num_encoder_layers: -1
+ beam_size: 0
+ num_decoder_layers: 2
+ nhead: 12
+ max_len: *max_text_length
+
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/nrtr/svtr_base_nrtr_syn.yml b/configs/rec/nrtr/svtr_base_nrtr_syn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..bfddad56586382016abc2c625fab14c955ccfc75
--- /dev/null
+++ b/configs/rec/nrtr/svtr_base_nrtr_syn.yml
@@ -0,0 +1,119 @@
+Global:
+ device: gpu
+ epoch_num: 60
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/syn/svtr_base_nrtr/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/syn/predicts_svtr_base_nrtr.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.0005 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: CosineAnnealingLR
+ warmup_epoch: 6
+
+Architecture:
+ model_type: rec
+ algorithm: NRTR
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 100]
+ out_char_num: 25
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: ['Conv','Conv','Conv','Conv','Conv','Conv', 'Conv','Conv', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ last_stage: False
+ prenorm: True
+ Decoder:
+ name: NRTRDecoder
+ num_encoder_layers: -1
+ beam_size: 0
+ num_decoder_layers: 6
+ nhead: 12
+ max_len: *max_text_length
+
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: STRLMDBDataSet
+ data_dir: ./
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ # - SVTRRecAug:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - SVTRResize:
+ image_shape: [3, 32, 100]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - SVTRResize:
+ image_shape: [3, 32, 100]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 4
diff --git a/configs/rec/nrtr/svtrv2_nrtr.yml b/configs/rec/nrtr/svtrv2_nrtr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..de74c869296e5674cb90ccc6be84faa2e7b6ce7f
--- /dev/null
+++ b/configs/rec/nrtr/svtrv2_nrtr.yml
@@ -0,0 +1,146 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_nrtr/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_nrtr.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: NRTR
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 128]
+ out_char_num: 25
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: ['Conv','Conv','Conv','Conv','Conv','Conv', 'Conv','Conv', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ last_stage: False
+ prenorm: True
+ Decoder:
+ name: NRTRDecoder
+ num_encoder_layers: -1
+ beam_size: 0
+ num_decoder_layers: 2
+ nhead: 12
+ max_len: *max_text_length
+
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/ote/svtr_base_h8_ote.yml b/configs/rec/ote/svtr_base_h8_ote.yml
new file mode 100644
index 0000000000000000000000000000000000000000..571ee717a9e7ada084e8bafada58d0c4fd3a6960
--- /dev/null
+++ b/configs/rec/ote/svtr_base_h8_ote.yml
@@ -0,0 +1,117 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtr_base_h8_ote/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_base_h8_ote.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: OTE
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 128]
+ out_char_num: 25
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: ['Conv','Conv','Conv','Conv','Conv','Conv', 'Conv','Conv', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ last_stage: False
+ sub_k: [[1, 1], [2, 1]]
+ prenorm: True
+ Decoder:
+ name: OTEDecoder
+ ar: True
+ num_decoder_layers: 1
+ num_heads: 12
+ max_len: *max_text_length
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/ote/svtr_base_ote.yml b/configs/rec/ote/svtr_base_ote.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8f97c2709dac2adafd3afc79b514e85cca19738c
--- /dev/null
+++ b/configs/rec/ote/svtr_base_ote.yml
@@ -0,0 +1,116 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtr_base_ote/
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_base_ote.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: OTE
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 128]
+ out_char_num: 25
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: ['Conv','Conv','Conv','Conv','Conv','Conv', 'Conv','Conv', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ last_stage: False
+ prenorm: True
+ Decoder:
+ name: OTEDecoder
+ ar: True
+ num_decoder_layers: 1
+ num_heads: 12
+ max_len: *max_text_length
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml b/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b2601eaefb8a79744806b0c9131416fa758b43a7
--- /dev/null
+++ b/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml
@@ -0,0 +1,140 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/focalsvtr_parseq_maxratio12
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img: ../ltb/img
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_focalsvtr_parseq_maxratio12.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # 4gpus 256bs/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: PARSeq
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: FocalSVTR
+ img_size: [32, 128]
+ depths: [6, 6, 6]
+ embed_dim: 96
+ sub_k: [[1, 1], [2, 1], [1, 1]]
+ focal_levels: [3, 3, 3]
+ last_stage: False
+ Decoder:
+ name: PARSeqDecoder
+ decode_ar: True
+ refine_iters: 1
+
+Loss:
+ name: PARSeqLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: &padding True
+ padding_rand: True
+ padding_doub: True
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ padding_rand: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ max_ratio: *max_ratio
+ batch_size_per_card: 128
+ num_workers: 4
diff --git a/configs/rec/parseq/svrtv2_parseq.yml b/configs/rec/parseq/svrtv2_parseq.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c1ffa08ce678a533d62b46bfe82c6d130b16ea99
--- /dev/null
+++ b/configs/rec/parseq/svrtv2_parseq.yml
@@ -0,0 +1,136 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_parseq
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ use_amp: True
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_parseq.txt
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # 4gpus 256bs/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: PARSeq
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ Decoder:
+ name: PARSeqDecoder
+ decode_ar: True
+ refine_iters: 1
+ only_attn: False
+
+Loss:
+ name: PARSeqLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/parseq/vit_parseq.yml b/configs/rec/parseq/vit_parseq.yml
new file mode 100644
index 0000000000000000000000000000000000000000..58cb6983c79fd9d7f836174d236d769d7cafdfe3
--- /dev/null
+++ b/configs/rec/parseq/vit_parseq.yml
@@ -0,0 +1,100 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/vit_parseq/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ use_amp: True
+ save_res_path: ./output/rec/u14m_filter/predicts_vit_parseq.txt
+ grad_clip_val: 20
+
+Optimizer:
+ name: AdamW
+ lr: 0.001485 # 2gpus 384bs/gpu
+ weight_decay: 0.
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: PARSeq
+ Transform:
+ Encoder:
+ name: ViT
+ Decoder:
+ name: PARSeqDecoder
+ decode_ar: True
+ refine_iters: 1
+
+Loss:
+ name: PARSeqLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAug:
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 384
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/robustscanner/resnet31_robustscanner.yml b/configs/rec/robustscanner/resnet31_robustscanner.yml
new file mode 100644
index 0000000000000000000000000000000000000000..13ff1274b111fedbe20a237b9dbfd39e1a1f9f42
--- /dev/null
+++ b/configs/rec/robustscanner/resnet31_robustscanner.yml
@@ -0,0 +1,102 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet31_robustscanner
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_resnet31_robustscanner.txt
+ use_amp: True
+
+Optimizer:
+ name: Adam
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: RobustScanner
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: ResNet31
+ # init_type: KaimingNormal
+ Decoder:
+ name: RobustScannerDecoder
+ enc_outchannles: 128
+ hybrid_dec_rnn_layers: 2
+ hybrid_dec_dropout: 0
+ position_dec_rnn_layers: 2
+ mask: True
+ encode_value: False
+ max_text_length: *max_text_length
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - PARSeqAug:
+ - ARLabelEncode: # Class handling label
+ - RobustScannerRecResizeImg:
+ image_shape: [3, 48, 48, 160]
+ width_downsample_ratio: 0.25
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'valid_ratio'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ARLabelEncode: # Class handling label
+ - RobustScannerRecResizeImg:
+ image_shape: [3, 48, 48, 160]
+ width_downsample_ratio: 0.25
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'valid_ratio'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 128
+ num_workers: 2
diff --git a/configs/rec/robustscanner/svtrv2_robustscanner.yml b/configs/rec/robustscanner/svtrv2_robustscanner.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b38bc63a59b96b85345f7c07246af5813cff155c
--- /dev/null
+++ b/configs/rec/robustscanner/svtrv2_robustscanner.yml
@@ -0,0 +1,134 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_robustscanner
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_robustscanner.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: robustscanner
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: RobustScannerDecoder
+ enc_outchannles: 128
+ hybrid_dec_rnn_layers: 2
+ hybrid_dec_dropout: 0
+ position_dec_rnn_layers: 2
+ mask: False
+ encode_value: False
+ max_text_length: *max_text_length
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/sar/resnet31_lstm_sar.yml b/configs/rec/sar/resnet31_lstm_sar.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a690d989a44e7ec2e1481b12fa7afbce5893482e
--- /dev/null
+++ b/configs/rec/sar/resnet31_lstm_sar.yml
@@ -0,0 +1,94 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet31_lstm_sar
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_resnet31_lstm_sar.txt
+ use_amp: True
+ grad_clip_val: 1.0
+
+Optimizer:
+ name: Adam
+ lr: 0.002 # for 4gpus bs256/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SAR
+ Transform:
+ Encoder:
+ name: ResNet31
+ Decoder:
+ name: SARDecoder
+ mask: True
+ use_lstm: True
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - PARSeqAug:
+ - ARLabelEncode: # Class handling label
+ - RobustScannerRecResizeImg:
+ image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
+ width_downsample_ratio: 0.25
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'valid_ratio'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ARLabelEncode: # Class handling label
+ - RobustScannerRecResizeImg:
+ image_shape: [3, 48, 48, 160]
+ width_downsample_ratio: 0.25
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'valid_ratio'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/sar/svtrv2_sar.yml b/configs/rec/sar/svtrv2_sar.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b727d0bedf1a8d9f90e7e81e3ea5791093d0b4f8
--- /dev/null
+++ b/configs/rec/sar/svtrv2_sar.yml
@@ -0,0 +1,128 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_sar
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_sar.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325 # for 4gpus bs128/gpu
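+  # appears to follow linear LR scaling: half of the 0.00065 used by the 4gpus bs256/gpu configs above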
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SAR
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: SARDecoder
+ mask: false
+ use_lstm: false
+
+Loss:
+ name: ARLoss
+
+PostProcess:
+ name: ARLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - ARLabelEncode: # Class handling label
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/seed/resnet31_lstm_seed_tps_on.yml b/configs/rec/seed/resnet31_lstm_seed_tps_on.yml
new file mode 100644
index 0000000000000000000000000000000000000000..61440b6533d5cc6a7903468269f133c9a5be69c1
--- /dev/null
+++ b/configs/rec/seed/resnet31_lstm_seed_tps_on.yml
@@ -0,0 +1,96 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet31_lstm_seed_tps_on
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+  save_res_path: ./output/rec/u14m_filter/predicts_resnet31_lstm_seed_tps_on.txt
+ use_amp: True
+ grad_clip_val: 1.0
+
+Optimizer:
+ name: Adam
+ lr: 0.002 # for 1gpus bs1024/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: aster
+ Transform:
+ name: Aster_TPS
+ tps_inputsize: [32, 64]
+ tps_outputsize: [32, 128]
+ Encoder:
+ name: ResNet_ASTER
+ Decoder:
+ name: ASTERDecoder
+ seed: True
+
+Loss:
+ name: SEEDLoss
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+PostProcess:
+ name: ARLabelDecode
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - Fasttext:
+ path: './cc.en.300.bin' # wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz; gzip -dk cc.en.300.bin.gz
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 1024
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - ARLabelEncode: # Class handling label
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/smtr/focalsvtr_smtr.yml b/configs/rec/smtr/focalsvtr_smtr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..232476767b8fdc2ff8437a4f781eaf69450c0cf1
--- /dev/null
+++ b/configs/rec/smtr/focalsvtr_smtr.yml
@@ -0,0 +1,150 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/focalsvtr_smtr_maxratio12
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model: ./output/rec/focalsvtr_smtr/best.pth
+ # ./output/focalnet_subs_nocmff_20ep_u14m_k8_max_ratio12_h8_norand1_h2_padrand_doub_96/best.pth
+ # ./output/rec/focalsvtr_smtr/best.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img: ../ltb/img
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_focalsvtr_smtr_maxratio12.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SMTR
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: FocalSVTR
+ img_size: [32, 128]
+ depths: [6, 6, 6]
+ embed_dim: 96
+ sub_k: [[1, 1], [2, 1], [1, 1]]
+ focal_levels: [3, 3, 3]
+ last_stage: False
+ Decoder:
+ name: SMTRDecoder
+ num_layer: 1
+ ds: True
+ max_len: *max_text_length
+ next_mode: &next True
+ sub_str_len: &subsl 5
+
+Loss:
+ name: SMTRLoss
+
+PostProcess:
+ name: SMTRLabelDecode
+ next_mode: *next
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: &padding True
+ padding_rand: True
+ padding_doub: True
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - PARSeqAug:
+ - SMTRLabelEncode: # Class handling label
+ sub_str_len: *subsl
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_subs', 'label_next', 'length_subs',
+ 'label_subs_pre', 'label_next_pre', 'length_subs_pre', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: False
+ padding_rand: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ max_ratio: *max_ratio
+ batch_size_per_card: 128
+ num_workers: 4
diff --git a/configs/rec/smtr/focalsvtr_smtr_long.yml b/configs/rec/smtr/focalsvtr_smtr_long.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d0b03d145d7b3a99c5aba754853f4975c9b2e481
--- /dev/null
+++ b/configs/rec/smtr/focalsvtr_smtr_long.yml
@@ -0,0 +1,133 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/focalsvtr_smtr_long
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img: ../ltb/img
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 200
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_focalsvtr_smtr_long.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SMTR
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: FocalSVTR
+ img_size: [32, 128]
+ depths: [6, 6, 6]
+ embed_dim: 96
+ sub_k: [[1, 1], [2, 1], [1, 1]]
+ focal_levels: [3, 3, 3]
+ last_stage: False
+ Decoder:
+ name: SMTRDecoder
+ num_layer: 1
+ ds: True
+ max_len: *max_text_length
+ next_mode: &next True
+ sub_str_len: &subsl 5
+ infer_aug: True
+
+Loss:
+ name: SMTRLoss
+
+PostProcess:
+ name: SMTRLabelDecode
+ next_mode: *next
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: &padding True
+ padding_rand: True
+ padding_doub: True
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - PARSeqAug:
+ - SMTRLabelEncode: # Class handling label
+ sub_str_len: *subsl
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_subs', 'label_next', 'length_subs',
+ 'label_subs_pre', 'label_next_pre', 'length_subs_pre', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ../ltb/
+ label_file_list: ['../ltb/ultra_long_70_list.txt']
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ARLabelEncode: # Class handling label
+ max_text_length: 200
+ - SliceResize:
+ image_shape: [3, 32, 128]
+ padding: False
+ max_ratio: 12
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1
+ num_workers: 2
diff --git a/configs/rec/smtr/readme.md b/configs/rec/smtr/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..ddda48770f153736228e80a0937ba7daf77bc6a7
--- /dev/null
+++ b/configs/rec/smtr/readme.md
@@ -0,0 +1,183 @@
+# SMTR
+
+- [SMTR](#smtr)
+ - [1. Introduction](#1-introduction)
+ - [2. Environment](#2-environment)
+ - [Dataset Preparation](#dataset-preparation)
+ - [3. Model Training / Evaluation](#3-model-training--evaluation)
+ - [Citation](#citation)
+
+
+
+## 1. Introduction
+
+Paper:
+
+> [Out of Length Text Recognition with Sub-String Matching](https://arxiv.org/abs/2407.12317)
+> Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang
+
+
+Scene Text Recognition (STR) methods have demonstrated robust performance in word-level text recognition. However, in real applications the text image is sometimes long because multiple horizontal words are detected together. This raises the need to build long text recognition models from readily available short, word-level text datasets, which has been less studied previously. In this paper, we term this Out of Length (OOL) text recognition. We establish the first Long Text Benchmark (LTB) to facilitate the assessment of different methods on long text recognition. Meanwhile, we propose a novel method called OOL Text Recognition with sub-String Matching (SMTR). SMTR comprises two cross-attention-based modules: one encodes a sub-string containing multiple characters into next and previous queries, and the other employs the queries to attend to the image features, matching the sub-string and simultaneously recognizing its next and previous characters. SMTR can recognize text of arbitrary length by iterating this process. To avoid being trapped in recognizing highly similar sub-strings, we introduce a regularization training that compels SMTR to effectively discover subtle differences between similar sub-strings for precise matching. In addition, we propose an inference augmentation to alleviate confusion caused by identical sub-strings and to improve overall recognition efficiency. Extensive experimental results reveal that SMTR, even when trained exclusively on short text, outperforms existing methods on public short-text benchmarks and exhibits a clear advantage on LTB.
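+
+The decoding loop described above can be summarized with a minimal sketch. This is a conceptual illustration only, not the repository's `SMTRDecoder`; the helper names `encode_substring` and `match_next_char` are hypothetical stand-ins for the two cross-attention modules:
+
+```python
+# Conceptual sketch of SMTR decoding (assumption: simplified from the paper's
+# description; the actual implementation lives in the SMTRDecoder module).
+def smtr_decode(image_feats, encode_substring, match_next_char,
+                sub_str_len=5, max_len=200, eos='</s>'):
+    text = []                                    # characters recognized so far
+    while len(text) < max_len:
+        context = text[-sub_str_len:]            # current sub-string (last few characters)
+        next_q, _prev_q = encode_substring(context)          # module 1: sub-string -> next/previous queries
+        next_char = match_next_char(next_q, image_feats)     # module 2: query attends to image features
+        if next_char == eos:                     # no successor found: recognition is complete
+            break
+        text.append(next_char)
+    return ''.join(text)
+```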
+
+The accuracy (%) and model files of SMTR on public scene text recognition datasets are as follows:
+
+- Syn: Synthetic datasets (MJ+ST) from [PARSeq](https://github.com/baudm/parseq)
+
+- U14M: Union14M-L from [Union14M](https://github.com/Mountchicken/Union14M/)
+
+- Test on Long Text Benchmark ([Download LTB](https://drive.google.com/drive/folders/1NChdlw7ustbXtlFBmh_0xnHvRkffb9Ge?usp=sharing)):
+
+| Model | Training Data | LTB | Config&Model&Log |
+| :-------: | :-----------: | :--: | :---------------------------------------------------------------------------------------------: |
+| SMTR | Syn | 39.6 | [link](https://drive.google.com/drive/folders/11SplakPPOFDMhPixv7ABNgjeTg4jKyfU?usp=sharing) |
+| SMTR | U14M | 51.0 | [link](https://drive.google.com/drive/folders/1-K5O0d0q9fhY5fJvU6nn5fFFtSMnbE_-?usp=drive_link) |
+| FocalSVTR | U14M | 42.1 | [link](https://drive.google.com/drive/folders/100xF5wFr7xSCVBYM1h_0d_8xv5Qeqobp?usp=sharing) |
+
+- Test on Common Benchmarks from [PARSeq](https://github.com/baudm/parseq):
+
+| Model | Training Data | IC13<br/>857 | SVT | IIIT5k<br/>3000 | IC15<br/>1811 | SVTP | CUTE80 | Avg | Config&Model&Log |
+| :-------: | :-----------: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :---------------------: |
+| SMTR | Syn | 97.4 | 94.9 | 97.4 | 88.4 | 89.9 | 96.2 | 94.02 | Same as the above table |
+| SMTR | U14M | 98.3 | 97.4 | 99.0 | 90.1 | 92.7 | 97.9 | 95.90 | Same as the above table |
+| FocalSVTR | U14M | 97.3 | 96.3 | 98.2 | 87.4 | 88.4 | 96.2 | 93.97 | Same as the above table |
+
+- Test on Union14M-L benchmark from [Union14M](https://github.com/Mountchicken/Union14M/).
+
+| Model | Training Data | Curve | Multi-<br/>Oriented | Artistic | Contextless | Salient | Multi-<br/>word | General | Avg | Config&Model&Log |
+| :-------: | :---------: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: |
+| SMTR | Syn | 74.2 | 30.6 | 58.5 | 67.6 | 79.6 | 75.1 | 67.9 | 64.79 | Same as the above table |
+| SMTR | U14M | 89.1 | 87.7 | 76.8 | 83.9 | 84.6 | 89.3 | 83.7 | 85.00 | Same as the above table |
+| FocalSVTR | U14M | 77.7 | 62.4 | 65.7 | 78.6 | 71.6 | 81.3 | 79.2 | 73.80 | Same as the above table |
+
+- Training and test on the Chinese datasets from the [Chinese Benchmark](https://github.com/FudanVI/benchmarking-chinese-text-recognition).
+
+| Model | Scene | Web | Document | Handwriting | Avg | Config&Model&Log |
+| :-----------: | :---: | :--: | :------: | :---------: | :---: | :---------------------------------------------------------------------------------------------: |
+| SMTR w/o Aug | 79.8 | 80.6 | 99.1 | 61.9 | 80.33 | [link](https://drive.google.com/drive/folders/1v8CK5GIu7wunnD5jFh2bLbusjyHeban5?usp=drive_link) |
+| SMTR w/ Aug | 83.4 | 83.0 | 99.3 | 65.1 | 82.68 | [link](https://drive.google.com/drive/folders/1SQnwSm0bOBQ0eMKKD08F_4Blkjie_3la?usp=drive_link) |
+
+Download all Configs, Models, and Logs from [Google Drive](https://drive.google.com/drive/folders/1dCuaWwCLP9xIHgy-7NtpeDLOvgk9NoKE?usp=drive_link).
+
+
+
+## 2. Environment
+
+- [PyTorch](http://pytorch.org/) version >= 1.13.0
+- Python version >= 3.7
+
+```shell
+git clone -b develop https://github.com/Topdu/OpenOCR.git
+cd OpenOCR
+# A100 Ubuntu 20.04 Cuda 11.8
+conda create -n openocr python==3.8
+conda activate openocr
+conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia
+pip install -r requirements.txt
+```
+
+#### Dataset Preparation
+
+- [English dataset download](https://github.com/baudm/parseq)
+
+- [Union14M-L download](https://github.com/Mountchicken/Union14M)
+
+- [Chinese dataset download](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download)
+
+- [LTB download](https://drive.google.com/drive/folders/1NChdlw7ustbXtlFBmh_0xnHvRkffb9Ge?usp=sharing)
+
+The expected filesystem structure is as follows:
+
+```
+benchmark_bctr
+├── benchmark_bctr_test
+│ ├── document_test
+│ ├── handwriting_test
+│ ├── scene_test
+│ └── web_test
+└── benchmark_bctr_train
+ ├── document_train
+ ├── handwriting_train
+ ├── scene_train
+ └── web_train
+evaluation
+├── CUTE80
+├── IC13_857
+├── IC15_1811
+├── IIIT5k
+├── SVT
+└── SVTP
+OpenOCR
+synth
+├── MJ
+│ ├── test
+│ ├── train
+│ └── val
+└── ST
+test # from PARSeq
+├── ArT
+├── COCOv1.4
+├── CUTE80
+├── IC13_1015
+├── IC13_1095
+├── IC13_857
+├── IC15_1811
+├── IC15_2077
+├── IIIT5k
+├── SVT
+├── SVTP
+└── Uber
+u14m # lmdb format
+├── artistic
+├── contextless
+├── curve
+├── general
+├── multi_oriented
+├── multi_words
+└── salient
+ltb # download link: https://drive.google.com/drive/folders/1NChdlw7ustbXtlFBmh_0xnHvRkffb9Ge?usp=sharing
+Union14M-LMDB-L # lmdb format
+├── train_challenging
+├── train_easy
+├── train_hard
+├── train_medium
+└── train_normal
+```
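+
+Before training, it can be useful to confirm that the LMDB folders above are readable. A minimal sketch, assuming the conventional `num-samples` key used by lmdb-format STR datasets and the `lmdb` Python package:
+
+```python
+# Print the sample count of each evaluation LMDB (assumes the standard
+# 'num-samples' key; adjust the root path to match your layout).
+import os
+import lmdb
+
+root = '../evaluation'
+for name in sorted(os.listdir(root)):
+    path = os.path.join(root, name)
+    if not os.path.isdir(path):
+        continue
+    env = lmdb.open(path, readonly=True, lock=False, readahead=False, meminit=False)
+    with env.begin(write=False) as txn:
+        num = txn.get(b'num-samples')
+    env.close()
+    print(name, int(num) if num is not None else 'missing num-samples key')
+```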
+
+
+
+## 3. Model Training / Evaluation
+
+Training:
+
+```shell
+# Multi GPU training
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/smtr/focalsvtr_smtr.yml
+# For RTX 4090
+NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/smtr/focalsvtr_smtr.yml
+```
+
+Evaluation:
+
+```shell
+# en
+python tools/eval_rec_all_ratio.py --c configs/rec/smtr/focalsvtr_smtr.yml
+# long text
+python tools/eval_rec_all_long_simple.py --c configs/rec/smtr/focalsvtr_smtr_long.yml
+# ch
+python tools/eval_rec_all_ch.py --c configs/rec/smtr/focalsvtr_smtr_ch.yml
+```
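+
+The YAML configs referenced above rely on anchors (e.g. `&bs`/`*bs`, `&max_ratio`/`*max_ratio`) to keep the sampler and loader settings in sync. They resolve automatically when the file is parsed; the sketch below uses plain PyYAML for inspection and is independent of the project's own config loader:
+
+```python
+# Load a config and confirm that YAML anchors resolve to shared values.
+import yaml  # PyYAML
+
+with open('configs/rec/smtr/focalsvtr_smtr.yml') as f:
+    cfg = yaml.safe_load(f)
+
+train = cfg['Train']
+print(train['sampler']['first_bs'])            # 256, defined via &bs
+print(train['loader']['batch_size_per_card'])  # 256, referenced via *bs
+print(train['loader']['max_ratio'])            # 12, defined via &max_ratio
+print(cfg['Eval']['loader']['max_ratio'])      # 12, referenced via *max_ratio
+```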
+
+## Citation
+
+```bibtex
+@article{Du2024SMTR,
+ title = {Out of Length Text Recognition with Sub-String Matching},
+  author = {Yongkun Du and Zhineng Chen and Caiyan Jia and Xieping Gao and Yu-Gang Jiang},
+ journal = {CoRR},
+ eprinttype = {arXiv},
+ primaryClass={cs.CV},
+ volume = {abs/2407.12317},
+ year = {2024},
+ url = {https://arxiv.org/abs/2407.12317}
+}
+```
diff --git a/configs/rec/smtr/svtrv2_smtr.yml b/configs/rec/smtr/svtrv2_smtr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e0c57cee2d77d3f03ca3f01d4ac83fbc22eb818d
--- /dev/null
+++ b/configs/rec/smtr/svtrv2_smtr.yml
@@ -0,0 +1,150 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_lnconv_smtr_maxratio12
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img: ../ltb/img
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_lnconv_smtr_maxratio12.txt
+ use_amp: True
+ distributed: true
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325 # for 4gpus bs128/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SMTR
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConv
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: False
+ Decoder:
+ name: SMTRDecoder
+ num_layer: 1
+ ds: True
+ max_len: *max_text_length
+ next_mode: &next True
+ sub_str_len: &subsl 5
+
+Loss:
+ name: SMTRLoss
+
+PostProcess:
+ name: SMTRLabelDecode
+ next_mode: *next
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: false
+ padding_rand: true
+ padding_doub: true
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - PARSeqAug:
+ - SMTRLabelEncode: # Class handling label
+ sub_str_len: *subsl
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_subs', 'label_next', 'length_subs',
+ 'label_subs_pre', 'label_next_pre', 'length_subs_pre', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/smtr/svtrv2_smtr_bi.yml b/configs/rec/smtr/svtrv2_smtr_bi.yml
new file mode 100644
index 0000000000000000000000000000000000000000..25bc7d36f9278d70c92c9d15790f65eb297cded2
--- /dev/null
+++ b/configs/rec/smtr/svtrv2_smtr_bi.yml
@@ -0,0 +1,136 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_lnconv_smtr_bi
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img: ../ltb/img
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_lnconv_smtr_bi.txt
+ use_amp: True
+ distributed: true
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SMTR
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConv
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: False
+ Decoder:
+ name: SMTRDecoder
+ num_layer: 1
+ ds: True
+ max_len: *max_text_length
+ next_mode: &next True
+ sub_str_len: &subsl 5
+ infer_aug: True
+
+Loss:
+ name: SMTRLoss
+
+PostProcess:
+ name: SMTRLabelDecode
+ next_mode: *next
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: false
+ padding_rand: true
+ padding_doub: true
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - PARSeqAug:
+ - SMTRLabelEncode: # Class handling label
+ sub_str_len: *subsl
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_subs', 'label_next', 'length_subs',
+ 'label_subs_pre', 'label_next_pre', 'length_subs_pre', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ../ltb/
+ label_file_list: ['../ltb/ultra_long_70_list.txt']
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ARLabelEncode: # Class handling label
+ max_text_length: 200
+ - SliceResize:
+ image_shape: [3, 32, 128]
+ padding: False
+ max_ratio: 12
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1
+ num_workers: 2
diff --git a/configs/rec/srn/resnet50_fpn_srn.yml b/configs/rec/srn/resnet50_fpn_srn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d62908a3da022d4a548dd0cecbb9d06aee6dc735
--- /dev/null
+++ b/configs/rec/srn/resnet50_fpn_srn.yml
@@ -0,0 +1,97 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet50_fpn_srn
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_resnet50_fpn_srn.txt
+ # find_unused_parameters: True
+ use_amp: True
+ grad_clip_val: 10
+
+Optimizer:
+ name: Adam
+ lr: 0.002 # for 4gpus bs128/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SRN
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: ResNet_FPN
+ layers: 50
+ Decoder:
+ name: SRNDecoder
+ hidden_dims: 512
+
+Loss:
+ name: SRNLoss
+ # smoothing: True
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+PostProcess:
+ name: SRNLabelDecode
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ channel_first: False
+ - PARSeqAugPIL:
+ - SRNLabelEncode: # Class handling label
+ - RecTVResize:
+          image_shape: [64, 256] # h, w
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ channel_first: False
+ - SRNLabelEncode: # Class handling label
+ - RecTVResize:
+          image_shape: [64, 256] # h, w
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 128
+ num_workers: 2
diff --git a/configs/rec/srn/svtrv2_srn.yml b/configs/rec/srn/svtrv2_srn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8c1662b3ad2d061c9282db86935b5112eca41fd4
--- /dev/null
+++ b/configs/rec/srn/svtrv2_srn.yml
@@ -0,0 +1,131 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_srn
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: ./tools/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ use_space_char: False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_srn.txt
+ # find_unused_parameters: True
+ use_amp: True
+ grad_clip_val: 10
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325 # for 4gpus bs128/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SRN
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: SRNDecoder
+ hidden_dims: 384
+
+Loss:
+ name: SRNLoss
+ # smoothing: True
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+PostProcess:
+ name: SRNLabelDecode
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - SRNLabelEncode: # Class handling label
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: ['../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - SRNLabelEncode: # Class handling label
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: 4
+ num_workers: 4
diff --git a/configs/rec/svtrs/convnextv2_ctc.yml b/configs/rec/svtrs/convnextv2_ctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..494f19d1969fe23a8ecb95ba9d299ab6df0f7b9b
--- /dev/null
+++ b/configs/rec/svtrs/convnextv2_ctc.yml
@@ -0,0 +1,105 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/convnextv2_ctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_convnextv2_ctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: ConvNeXtV2
+ out_channels: 256
+ depths: [2, 2, 8, 2]
+ dims: [80, 160, 320, 640]
+ drop_path_rate: 0.1
+ strides: [[4,4], [2,1], [2,1], [1,1]]
+ last_stage: True
+ Decoder:
+ name: CTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/convnextv2_h8_ctc.yml b/configs/rec/svtrs/convnextv2_h8_ctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2cffc12a590e7c1b6ea334309f6316f9f1335310
--- /dev/null
+++ b/configs/rec/svtrs/convnextv2_h8_ctc.yml
@@ -0,0 +1,105 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/convnextv2_h8_ctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_convnextv2_h8_ctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: ConvNeXtV2
+ out_channels: 256
+ depths: [2, 2, 8, 2]
+ dims: [80, 160, 320, 640]
+ drop_path_rate: 0.1
+ strides: [[4,4], [1,1], [2,1], [1,1]]
+ last_stage: True
+ Decoder:
+ name: CTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/convnextv2_h8_rctc.yml b/configs/rec/svtrs/convnextv2_h8_rctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ebe39eb5b44b62544f68034a35c9c54486f9ca0f
--- /dev/null
+++ b/configs/rec/svtrs/convnextv2_h8_rctc.yml
@@ -0,0 +1,106 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/convnextv2_h8_rctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_convnextv2_h8_rctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: ConvNeXtV2
+ out_channels: 256
+ depths: [2, 2, 8, 2]
+ dims: [80, 160, 320, 640]
+ drop_path_rate: 0.1
+ strides: [[4,4], [1,1], [2,1], [1,1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/convnextv2_rctc.yml b/configs/rec/svtrs/convnextv2_rctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c4a416b5b25de899cb26993a5b42fa1827a99a2a
--- /dev/null
+++ b/configs/rec/svtrs/convnextv2_rctc.yml
@@ -0,0 +1,106 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/convnextv2_rctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_convnextv2_rctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: ConvNeXtV2
+ out_channels: 256
+ depths: [2, 2, 8, 2]
+ dims: [80, 160, 320, 640]
+ drop_path_rate: 0.1
+ strides: [[4,4], [2,1], [2,1], [1,1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml b/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f2e76fb2e55ad72196d0efb010f183eff8179d44
--- /dev/null
+++ b/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml
@@ -0,0 +1,105 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/convnextv2_tiny_h8_ctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+  save_res_path: ./output/rec/u14m_filter/predicts_convnextv2_tiny_h8_ctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: ConvNeXtV2
+ out_channels: 256
+ depths: [3, 3, 9, 3]
+ dims: [96, 192, 384, 768]
+ drop_path_rate: 0.1
+ strides: [[4,4], [1,1], [2,1], [1,1]]
+ last_stage: True
+ Decoder:
+ name: CTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml b/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d1e3d8c6f964f176988b2e42cf3d4a1c41c34688
--- /dev/null
+++ b/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml
@@ -0,0 +1,106 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/convnext_tiny_rctc
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+  save_res_path: ./output/rec/u14m_filter/predicts_convnextv2_tiny_h8_rctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: ConvNeXtV2
+ out_channels: 256
+ depths: [3, 3, 9, 3]
+ dims: [96, 192, 384, 768]
+ drop_path_rate: 0.1
+ strides: [[4,4], [1,1], [2,1], [1,1]]
+ last_stage: False
+ feat2d: True
+ Decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/crnn_ctc.yml b/configs/rec/svtrs/crnn_ctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b26252fc0655739a565839a9ecc864aa2ab2e193
--- /dev/null
+++ b/configs/rec/svtrs/crnn_ctc.yml
@@ -0,0 +1,99 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/crnn_ctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_crnn_ctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: CRNN
+ Transform:
+ Encoder:
+ name: ResNet_ASTER
+ Decoder:
+ name: CTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/crnn_ctc_long.yml b/configs/rec/svtrs/crnn_ctc_long.yml
new file mode 100644
index 0000000000000000000000000000000000000000..1c4818d5db46f843dd601cda9468b13a4c16f670
--- /dev/null
+++ b/configs/rec/svtrs/crnn_ctc_long.yml
@@ -0,0 +1,116 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/crnn_ctc_long/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_crnn_ctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: CRNN
+ Transform:
+ Encoder:
+ name: ResNet_ASTER
+ Decoder:
+ name: CTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ base_shape: [[32, 32], [64, 32], [96, 32], [128, 32]]
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ max_ratio: 20
+ num_workers: 4
diff --git a/configs/rec/svtrs/focalnet_base_ctc.yml b/configs/rec/svtrs/focalnet_base_ctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ca0bdeace610e736a48582a36b35d27f30d42bb5
--- /dev/null
+++ b/configs/rec/svtrs/focalnet_base_ctc.yml
@@ -0,0 +1,108 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/focalnet_base_ctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/predicts_svtr_tiny.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ # warmup_epoch: 2
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: FocalSVTR
+ img_size: [32, 128]
+ depths: [2, 2, 9, 2]
+ embed_dim: 96
+ sub_k: [[2, 1], [1, 1], [1, 1], [1, 1]]
+ focal_levels: [3, 3, 3, 3]
+ max_khs: [7, 3, 3, 3]
+ focal_windows: [3, 3, 3, 3]
+ out_channels: 384
+ last_stage: True
+ Decoder:
+ name: CTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/focalnet_base_rctc.yml b/configs/rec/svtrs/focalnet_base_rctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..1d2d48977169e61cae8bb7ac1b69d4a12608116f
--- /dev/null
+++ b/configs/rec/svtrs/focalnet_base_rctc.yml
@@ -0,0 +1,109 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/focalnet_base_rctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_tiny.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: FocalSVTR
+ img_size: [32, 128]
+ depths: [2, 2, 9, 2]
+ embed_dim: 96
+ sub_k: [[2, 1], [1, 1], [1, 1], [1, 1]]
+ focal_levels: [3, 3, 3, 3]
+ max_khs: [7, 3, 3, 3]
+ focal_windows: [3, 3, 3, 3]
+ out_channels: 384
+ last_stage: False
+ feat2d: True
+ Decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/focalsvtr_ctc.yml b/configs/rec/svtrs/focalsvtr_ctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..985529bf8bf4ba1d929f867e3962ef16a81542ea
--- /dev/null
+++ b/configs/rec/svtrs/focalsvtr_ctc.yml
@@ -0,0 +1,106 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/focalsvtr_ctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_tiny.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: FocalSVTR
+ img_size: [32, 128]
+ depths: [6, 6, 9]
+ embed_dim: 96
+ sub_k: [[1, 1], [2, 1], [1, 1]]
+ focal_levels: [3, 3, 3]
+ out_channels: 256
+ last_stage: True
+ Decoder:
+ name: CTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/focalsvtr_rctc.yml b/configs/rec/svtrs/focalsvtr_rctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2dcb20a992904643d2f60cd47081ee19bb69bae0
--- /dev/null
+++ b/configs/rec/svtrs/focalsvtr_rctc.yml
@@ -0,0 +1,107 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/focalsvtr_rctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_tiny.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: FocalSVTR
+ img_size: [32, 128]
+ depths: [6, 6, 9]
+ embed_dim: 96
+ sub_k: [[1, 1], [2, 1], [1, 1]]
+ focal_levels: [3, 3, 3]
+ out_channels: 256
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/resnet45_trans_ctc.yml b/configs/rec/svtrs/resnet45_trans_ctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7ba65c147929ec507d662be794510c4af2b34770
--- /dev/null
+++ b/configs/rec/svtrs/resnet45_trans_ctc.yml
@@ -0,0 +1,103 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet45_3en_ctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_tiny.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: ResNet45
+ in_channels: 3
+ strides: [2, 1, 2, 1, 1]
+ last_stage: True
+ trans_layer: 3
+ Decoder:
+ name: CTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/resnet45_trans_rctc.yml b/configs/rec/svtrs/resnet45_trans_rctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2a209d49e0cb62aeac07b496fb86925d1007b4e3
--- /dev/null
+++ b/configs/rec/svtrs/resnet45_trans_rctc.yml
@@ -0,0 +1,104 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet45_3en_rctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_tiny.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.0
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: ResNet45
+ in_channels: 3
+ strides: [2, 1, 2, 1, 1]
+ last_stage: False
+ # out_channels: 256
+ trans_layer: 3
+ Decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/svtr_base_ctc.yml b/configs/rec/svtrs/svtr_base_ctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..46af7c367572ee0759ecc4dbe39a5f1621a6fd3b
--- /dev/null
+++ b/configs/rec/svtrs/svtr_base_ctc.yml
@@ -0,0 +1,110 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtr_base_ctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_tiny.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 128]
+ out_char_num: 32
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: ['Conv','Conv','Conv','Conv','Conv','Conv', 'Conv','Conv', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ last_stage: True
+ prenorm: True
+ Decoder:
+ name: CTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/svtr_base_rctc.yml b/configs/rec/svtrs/svtr_base_rctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ab8cbb12c36cdd14fff20710f1fc3bd3afa1820f
--- /dev/null
+++ b/configs/rec/svtrs/svtr_base_rctc.yml
@@ -0,0 +1,111 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtr_base_rctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_tiny.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 128]
+ out_char_num: 32
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: ['Conv','Conv','Conv','Conv','Conv','Conv', 'Conv','Conv', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ last_stage: False
+ prenorm: True
+ feature2d: True
+ Decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/svtrnet_ctc_syn.yml b/configs/rec/svtrs/svtrnet_ctc_syn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..74bff8e9a3eefdbd0a47630c5d61fcf1422ea643
--- /dev/null
+++ b/configs/rec/svtrs/svtrnet_ctc_syn.yml
@@ -0,0 +1,111 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/syn/svtr_tiny/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/syn/predicts_svtr_tiny.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.0005 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: CosineAnnealingLR
+ warmup_epoch: 2
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: SVTRNet
+ img_size: [32, 100]
+    out_char_num: 25 # W//4 or W//8 or W//12
+ out_channels: 192
+ patch_merging: 'Conv'
+ embed_dim: [64, 128, 256]
+ depth: [3, 6, 3]
+ num_heads: [2, 4, 8]
+ mixer: ['Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[7, 11], [7, 11], [7, 11]]
+ last_stage: True
+ prenorm: False
+ Decoder:
+ name: CTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: STRLMDBDataSet
+ data_dir: ./
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ # - SVTRRecAug:
+ # aug_type: 0 # or 1
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - SVTRResize:
+ image_shape: [3, 32, 100]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - SVTRResize:
+ image_shape: [3, 32, 100]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/vit_ctc.yml b/configs/rec/svtrs/vit_ctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d526b3ee227de73d545756a2774e5e5209d61
--- /dev/null
+++ b/configs/rec/svtrs/vit_ctc.yml
@@ -0,0 +1,103 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/vit_ctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtr_tiny.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325 # for 4gpus bs128/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: ViT
+ img_size: [32, 128]
+ patch_size: [4, 4]
+ last_stage: True
+ feat2d: False
+ Decoder:
+ name: CTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrs/vit_rctc.yml b/configs/rec/svtrs/vit_rctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b4b3375f500989fa65b8776b104d6784ad84923d
--- /dev/null
+++ b/configs/rec/svtrs/vit_rctc.yml
@@ -0,0 +1,103 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/vit_rctc/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_vit_rctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325 # for 4gpus bs128/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ Encoder:
+ name: ViT
+ img_size: [32, 128]
+ patch_size: [4, 4]
+ last_stage: False
+ feat2d: True
+ Decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation/
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/svtrv2/readme.md b/configs/rec/svtrv2/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..0711c5b9968b11fd3744ccac9983fd7ffc48df0b
--- /dev/null
+++ b/configs/rec/svtrv2/readme.md
@@ -0,0 +1,143 @@
+# SVTRv2
+
+- [SVTRv2](#svtrv2)
+ - [1. Introduction](#1-introduction)
+ - [1.1 Models and Results](#11-models-and-results)
+ - [2. Environment](#2-environment)
+ - [3. Model Training / Evaluation](#3-model-training--evaluation)
+ - [Dataset Preparation](#dataset-preparation)
+ - [Training](#training)
+ - [Evaluation](#evaluation)
+ - [Inference](#inference)
+ - [Latency Measurement](#latency-measurement)
+ - [Citation](#citation)
+
+
+
+## 1. Introduction
+
+Paper:
+
+> [SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition](./SVTRv2.pdf)
+> Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang
+
+
+Connectionist temporal classification (CTC)-based scene text recognition (STR) methods, e.g., SVTR, are widely employed in OCR applications, mainly due to their simple architecture, which only contains a visual model and a CTC-aligned linear classifier, and therefore fast inference. However, they generally have worse accuracy than encoder-decoder-based methods (EDTRs), particularly in challenging scenarios. In this paper, we propose SVTRv2, a CTC model that beats leading EDTRs in both accuracy and inference speed. SVTRv2 introduces novel upgrades to handle text irregularity and utilize linguistic context, which endows it with the capability to deal with challenging and diverse text instances. First, a multi-size resizing (MSR) strategy is proposed to adaptively resize the text and maintain its readability. Meanwhile, we introduce a feature rearrangement module (FRM) to ensure that visual features accommodate the alignment requirement of CTC well, thus alleviating the alignment puzzle. Second, we propose a semantic guidance module (SGM). It integrates linguistic context into the visual model, allowing it to leverage language information for improved accuracy. Moreover, SGM can be omitted at the inference stage and would not increase the inference cost. We evaluate SVTRv2 in both standard and recent challenging benchmarks, where SVTRv2 is fairly compared with 24 mainstream STR models across multiple scenarios, including different types of text irregularity, languages, and long text. The results indicate that SVTRv2 surpasses all the EDTRs across the scenarios in terms of accuracy and speed.
+
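+As a concrete illustration of the multi-size resizing (MSR) idea described above, the sketch below picks the target shape from the image's aspect ratio instead of forcing a single fixed size; this mirrors what the `base_shape` / `max_ratio` settings of `RatioDataSetTVResize` and `RatioSampler` control in the configs of this directory. It is a minimal sketch for intuition only, with placeholder function names, not the repository's implementation.
+
+```python
+from PIL import Image
+
+
+def msr_target_shape(width, height, base_h=32, max_ratio=4):
+    """Map an image's aspect ratio to one of the allowed shapes,
+    e.g. 32x32, 64x32, 96x32 or 128x32 when base_h=32 and max_ratio=4."""
+    ratio = max(1, min(max_ratio, round(width / height)))
+    return base_h * ratio, base_h  # (target_w, target_h)
+
+
+def msr_resize(img: Image.Image, base_h=32, max_ratio=4) -> Image.Image:
+    """Resize a text crop to its aspect-ratio-matched target shape."""
+    target_w, target_h = msr_target_shape(*img.size, base_h, max_ratio)
+    return img.resize((target_w, target_h))
+```
+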
+### 1.1 Models and Results
+
+The accuracy (%) and model files of SVTRv2 on the public scene text recognition datasets are as follows:
+
+Download all Configs, Models, and Logs from [Google Drive](https://drive.google.com/drive/folders/1i2EZVT-oxfDIDdhwQRm9E6Fk8s6qD3C1?usp=sharing).
+
+| Model | Model size (M) | FPS |
+| :------: | :--------: | :-: |
+| SVTRv2-T | 5.13 | 5.0 |
+| SVTRv2-S | 11.25 | 5.3 |
+| SVTRv2-B | 19.76 | 7.0 |
+
+- Test on Common Benchmarks from [PARSeq](https://github.com/baudm/parseq):
+
+| Model | Training Data | IC13<br/>857 | SVT | IIIT5k<br/>3000 | IC15<br/>1811 | SVTP | CUTE80 | Avg | Config&Model&Log |
+| :------: | :----------------------------------------------------------: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :-----------------------------------------------------------------------: |
+| SVTRv2-B | Synthetic datasets (MJ+ST) | 97.7 | 94.0 | 97.3 | 88.1 | 91.2 | 95.8 | 94.02 | TODO |
+| SVTRv2-T | [Union14M-L-Filter](../../../docs/svtrv2.md#dataset-details) | 98.6 | 96.6 | 98.0 | 88.4 | 90.5 | 96.5 | 94.78 | [Google drive](https://drive.google.com/drive/folders/12ZUGkCS7tEhFhWa2RKKtyB0tPjhH4d9s?usp=drive_link) |
+| SVTRv2-S | [Union14M-L-Filter](../../../docs/svtrv2.md#dataset-details) | 99.0 | 98.3 | 98.5 | 89.5 | 92.9 | 98.6 | 96.13 | [Google drive](https://drive.google.com/drive/folders/1mOG3EUAOsmD16B-VIelVDYf_O64q0G3M?usp=drive_link) |
+| SVTRv2-B | [Union14M-L-Filter](../../../docs/svtrv2.md#dataset-details) | 99.2 | 98.0 | 98.7 | 91.1 | 93.5 | 99.0 | 96.57 | [Google drive](https://drive.google.com/drive/folders/11u11ptDzQ4BF9RRsOYdZnXl6ell2h4jN?usp=drive_link) |
+
+- Test on Union14M-L benchmark from [Union14M](https://github.com/Mountchicken/Union14M/).
+
+| Model | Training Data | Curve | Multi-<br/>Oriented | Artistic | Contextless | Salient | Multi-<br/>word | General | Avg | Config&Model&Log |
+| :------: | :----------------------------------------------------------: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: |
+| SVTRv2-B | Synthetic datasets (MJ+ST) | 74.6 | 25.2 | 57.6 | 69.7 | 77.9 | 68.0 | 66.9 | 62.83 | Same as the above table |
+| SVTRv2-T | [Union14M-L-Filter](../../../docs/svtrv2.md#dataset-details) | 83.6 | 76.0 | 71.2 | 82.4 | 77.2 | 82.3 | 80.7 | 79.05 | Same as the above table |
+| SVTRv2-S | [Union14M-L-Filter](../../../docs/svtrv2.md#dataset-details) | 88.3 | 84.6 | 76.5 | 84.3 | 83.3 | 85.4 | 83.5 | 83.70 | Same as the above table |
+| SVTRv2-B | [Union14M-L-Filter](../../../docs/svtrv2.md#dataset-details) | 90.6 | 89.0 | 79.3 | 86.1 | 86.2 | 86.7 | 85.1 | 86.14 | Same as the above table |
+
+- Test on [LTB](../smtr/readme.md) and [OST](https://github.com/wangyuxin87/VisionLAN).
+
+| Model | Training Data | LTB | OST | Config&Model&Log |
+| :------: | :----------------------------------------------------------: | :---: | :--: | :---------------------: |
+| SVTRv2-T | [Union14M-L-Filter](../../../docs/svtrv2.md#dataset-details) | 47.83 | 71.4 | Same as the above table |
+| SVTRv2-S | [Union14M-L-Filter](../../../docs/svtrv2.md#dataset-details) | 47.57 | 78.0 | Same as the above table |
+| SVTRv2-B | [Union14M-L-Filter](../../../docs/svtrv2.md#dataset-details) | 50.23 | 80.0 | Same as the above table |
+
+- Training and test on the Chinese dataset from the [Chinese Benchmark](https://github.com/FudanVI/benchmarking-chinese-text-recognition):
+
+| Model | Scene | Web | Document | Handwriting | Avg | Config&Model&Log |
+| :------: | :---: | :--: | :------: | :---------: | :---: | :-----------------------------------------------------------------------------------------------------: |
+| SVTRv2-T | 77.8 | 78.8 | 99.3 | 62.0 | 79.45 | [Google drive](https://drive.google.com/drive/folders/1vqTFonJV83SXVFrGhL31zXq7aOLwjnGD?usp=drive_link) |
+| SVTRv2-S | 81.1 | 81.2 | 99.3 | 65.0 | 81.64 | [Google drive](https://drive.google.com/drive/folders/1X3hqArfvRIRtuYLHDtSQheQmDc_oXpY6?usp=drive_link) |
+| SVTRv2-B | 83.5 | 83.3 | 99.5 | 67.0 | 83.31 | [Google drive](https://drive.google.com/drive/folders/1ZDECKXf8zZFhcKKKpvicg43Ho85uDZkF?usp=drive_link) |
+
+
+
+## 2. Environment
+
+- [PyTorch](http://pytorch.org/) version >= 1.13.0
+- Python version >= 3.7
+
+```shell
+git clone -b develop https://github.com/Topdu/OpenOCR.git
+cd OpenOCR
+# Ubuntu 20.04 Cuda 11.8
+conda create -n openocr python==3.8
+conda activate openocr
+conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia
+pip install -r requirements.txt
+```
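+
+As an optional sanity check after installation, the following generic PyTorch snippet (not part of the OpenOCR tooling) should report a CUDA-enabled build:
+
+```python
+import torch
+
+print(torch.__version__)          # expected: 2.2.x with the setup above
+print(torch.cuda.is_available())  # True if the CUDA build matches the driver
+print(torch.cuda.device_count())  # number of visible GPUs
+```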
+
+
+
+## 3. Model Training / Evaluation
+
+### Dataset Preparation
+
+Refer to [Downloading Datasets](../../../docs/svtrv2.md#downloading-datasets).
+
+### Training
+
+```shell
+# First stage
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml
+
+# Second stage
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Global.pretrained_model=./output/rec/u14m_filter/svtrv2_rctc/best.pth
+
+# For multiple RTX 4090 GPUs
+NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml
+# 20 epochs take about 6 hours
+```
+
+### Evaluation
+
+```shell
+# short text: Common, Union14M-Benchmark, OST
+python tools/eval_rec_all_ratio.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml
+```
+
+After a successful run, the results are saved as a CSV file under the `output_dir` specified in the config file.
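+
+To inspect those numbers programmatically, a minimal sketch is shown below. It only assumes that the evaluation script writes one or more `.csv` files under `output_dir`; the exact file name and column layout depend on the script and config.
+
+```python
+import csv
+import glob
+import os
+
+# Global.output_dir from configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml
+output_dir = './output/rec/u14m_filter/svtrv2_smtr_gtc_rctc_maxratio12'
+
+for path in sorted(glob.glob(os.path.join(output_dir, '*.csv'))):
+    print(f'== {os.path.basename(path)} ==')
+    with open(path, newline='') as f:
+        for row in csv.reader(f):
+            print(row)
+```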
+
+### Inference
+
+```shell
+python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Global.infer_img=/path/img_fold or /path/img_file
+```
+
+### Latency Measurement
+
+First, download the IIIT5K images from [Google Drive](https://drive.google.com/drive/folders/1Po1LSBQb87DxGJuAgLNxhsJ-pdXxpIfS?usp=drive_link). Then run the following command:
+
+```shell
+python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml --o Global.infer_img=../iiit5k_test_image
+```
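+
+The command above is the repository's way of measuring latency. If you want to time a model in your own code, the usual pattern is a few warm-up passes followed by synchronized, timed runs; the sketch below assumes a generic callable `model(image)` and is not tied to OpenOCR's API.
+
+```python
+import time
+
+import numpy as np
+import torch
+
+
+def measure_latency(model, image, warmup=5, runs=50):
+    """Average per-image latency in milliseconds."""
+    for _ in range(warmup):          # warm-up passes (caches, cudnn autotune)
+        model(image)
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    start = time.time()
+    for _ in range(runs):
+        model(image)
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    return (time.time() - start) / runs * 1000.0
+
+
+def dummy_model(image):              # stand-in for a real recognizer
+    return image.mean()
+
+
+img = np.random.uniform(0, 255, [32, 128, 3]).astype(np.uint8)
+print(f'{measure_latency(dummy_model, img):.3f} ms/image')
+```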
+
+## Citation
+
+```bibtex
+@article{Du2024SVTRv2,
+  title = {SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition},
+  author = {Du, Yongkun and Chen, Zhineng and Xie, Hongtao and Jia, Caiyan and Jiang, Yu-Gang},
+ year = {2024}
+}
+```
diff --git a/configs/rec/svtrv2/repsvtr_ch.yml b/configs/rec/svtrv2/repsvtr_ch.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fd683128dcf0355c38f85c21a59f81d7bc6eb5a0
--- /dev/null
+++ b/configs/rec/svtrv2/repsvtr_ch.yml
@@ -0,0 +1,121 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/repsvtr_ch/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model: ./openocr_repsvtr_ch.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/ppocr_keys_v1.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char True
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_ctc.txt
+ use_amp: True
+ project_name: resvtr_ctc_nosgm_ds
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTRv2_mobile
+ Transform:
+ Encoder:
+ name: RepSVTREncoder
+ Decoder:
+ name: CTCDecoder
+ svtr_encoder:
+ dims: 256
+ depth: 2
+ hidden_dims: 256
+ kernel_size: [1, 3]
+ use_guide: True
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ ignore_space: False
+ # is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ base_shape: [[32, 32], [64, 32], [96, 32], [128, 32]]
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecDynamicResize:
+ image_shape: [48, 320]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1
+ num_workers: 4
diff --git a/configs/rec/svtrv2/svtrv2_ch.yml b/configs/rec/svtrv2/svtrv2_ch.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6ee2e8f42977e8a9705523b9ab6dea0bc2f4248e
--- /dev/null
+++ b/configs/rec/svtrv2/svtrv2_ch.yml
@@ -0,0 +1,133 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_ctc_u14m_two33_tvresize/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model: ./openocr_svtrv2_ch.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/ppocr_keys_v1.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char True
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_ctc.txt
+ use_amp: True
+ project_name: svtrv2_ctc_nosgm_ds
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTRv2_server
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ kernel_sizes: [[5,5,5,5,5,5], [5,5,5,5,5,5], [-1]]
+ num_convs: [[1,1,1,1,1,1], [1,1,1,1,1,1], [-1]]
+ sub_k: [[2, 1], [2, 1], [-1, -1]]
+ last_stage: False
+ feat2d: True
+ pope_bias: True
+ Decoder:
+ name: CTCDecoder
+ svtr_encoder:
+ dims: 256
+ depth: 2
+ hidden_dims: 256
+ kernel_size: [1, 3]
+ use_guide: True
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ ignore_space: False
+ # is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ base_shape: [[32, 32], [64, 32], [96, 32], [128, 32]]
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecDynamicResize:
+ image_shape: [48, 320]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 4
diff --git a/configs/rec/svtrv2/svtrv2_ctc.yml b/configs/rec/svtrv2/svtrv2_ctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..da7b413a7d4186fbf0efe4b59dc91f2e2030e32b
--- /dev/null
+++ b/configs/rec/svtrv2/svtrv2_ctc.yml
@@ -0,0 +1,136 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_ctc/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_ctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTRv2
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ out_channels: 256
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: True
+ feat2d: False
+ Decoder:
+ name: CTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: 4
+ num_workers: 4
diff --git a/configs/rec/svtrv2/svtrv2_rctc.yml b/configs/rec/svtrv2/svtrv2_rctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f5d9ff312855204333a2f03f6b33cdd34e03e21c
--- /dev/null
+++ b/configs/rec/svtrv2/svtrv2_rctc.yml
@@ -0,0 +1,135 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_rctc/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_rctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTRv2
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: 4
+ num_workers: 4
diff --git a/configs/rec/svtrv2/svtrv2_small_rctc.yml b/configs/rec/svtrv2/svtrv2_small_rctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..95329cd90aca384bd28a74416b13651fb53f2b03
--- /dev/null
+++ b/configs/rec/svtrv2/svtrv2_small_rctc.yml
@@ -0,0 +1,135 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_small_rctc/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_rctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTRv2
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [96, 192, 384]
+ depths: [3, 6, 3]
+ num_heads: [3, 6, 12]
+ mixer: [['Conv','Conv','Conv'],['Conv','Conv','Conv','FGlobal','Global','Global'],['Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: 4
+ num_workers: 4
diff --git a/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..aa364920e76a511ccb0b5933a79343038b12f0f8
--- /dev/null
+++ b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml
@@ -0,0 +1,162 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_smtr_gtc_rctc_maxratio12
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./output/rec/u14m_filter/svtrv2_rctc/best.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_smtr_gtc_rctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325 # for 4gpus bs128/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTRv2
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: GTCDecoder
+ infer_gtc: True
+ detach: False
+ gtc_decoder:
+ name: SMTRDecoder
+ num_layer: 1
+ ds: True
+ max_len: *max_text_length
+ next_mode: &next True
+ sub_str_len: &subsl 5
+ ctc_decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: GTCLoss
+ ctc_weight: 0.1
+ gtc_loss:
+ name: SMTRLoss
+
+PostProcess:
+ name: GTCLabelDecode
+ gtc_label_decode:
+ name: SMTRLabelDecode
+ next_mode: *next
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecGTCMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - GTCLabelEncode: # Class handling label
+ gtc_label_encode:
+ name: SMTRLabelEncode
+ sub_str_len: *subsl
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_subs', 'label_next', 'length_subs',
+ 'label_subs_pre', 'label_next_pre', 'length_subs_pre', 'length', 'ctc_label', 'ctc_length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - GTCLabelEncode: # Class handling label
+ gtc_label_encode:
+ name: ARLabelEncode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'ctc_label', 'ctc_length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8d3c3d7dfa98963c364ebf5ebc960ac33ec6e26d
--- /dev/null
+++ b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml
@@ -0,0 +1,153 @@
+Global:
+ device: gpu
+ epoch_num: 100
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/ch/svtrv2_smtr_gtc_rctc_ch
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_batch_step: [0, 2000]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: False
+ pretrained_model:
+ # ./output/rec/u14m_filter/svtrv2_rctc/best.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/ch/predicts_svtrv2_smtr_gtc_rctc_ch.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 5
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTRv2
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: GTCDecoder
+ infer_gtc: True
+ detach: False
+ gtc_decoder:
+ name: SMTRDecoder
+ num_layer: 1
+ ds: True
+ max_len: *max_text_length
+ next_mode: &next True
+ sub_str_len: &subsl 5
+ ctc_decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: GTCLoss
+ ctc_weight: 0.1
+ gtc_loss:
+ name: SMTRLoss
+
+PostProcess:
+ name: GTCLabelDecode
+ gtc_label_decode:
+ name: SMTRLabelDecode
+ next_mode: *next
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecGTCMetric
+ main_indicator: acc
+ # is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../benchmark_bctr/benchmark_bctr_train/document_train',
+ '../benchmark_bctr/benchmark_bctr_train/handwriting_train',
+ '../benchmark_bctr/benchmark_bctr_train/scene_train',
+ '../benchmark_bctr/benchmark_bctr_train/web_train'
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - GTCLabelEncode: # Class handling label
+ gtc_label_encode:
+ name: SMTRLabelEncode
+ sub_str_len: *subsl
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_subs', 'label_next', 'length_subs',
+ 'label_subs_pre', 'label_next_pre', 'length_subs_pre', 'length', 'ctc_label', 'ctc_length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 8
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: ['../benchmark_bctr/benchmark_bctr_test/scene_test']
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - GTCLabelEncode: # Class handling label
+ gtc_label_encode:
+ name: ARLabelEncode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'ctc_label', 'ctc_length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
diff --git a/configs/rec/svtrv2/svtrv2_tiny_rctc.yml b/configs/rec/svtrv2/svtrv2_tiny_rctc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..121163e516a44963a9c1b5cf1d6f09223dab74ee
--- /dev/null
+++ b/configs/rec/svtrv2/svtrv2_tiny_rctc.yml
@@ -0,0 +1,135 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_tiny_rctc/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_rctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTRv2
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [64, 128, 256]
+ depths: [3, 6, 3]
+ num_heads: [2, 4, 8]
+ mixer: [['Conv','Conv','Conv'],['Conv','Conv','Conv','FGlobal','Global','Global'],['Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length']
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: 4
+ num_workers: 4
diff --git a/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml b/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml
new file mode 100644
index 0000000000000000000000000000000000000000..48a8589b9ec66415078d72960c100312b1d50264
--- /dev/null
+++ b/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml
@@ -0,0 +1,103 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet45_trans_visionlan_LA/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./output/rec/u14m_filter/resnet45_trans_visionlan_LF2/best.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_resnet45_trans_visionlan_LA.txt
+ grad_clip_val: 20
+ use_amp: True
+
+Optimizer:
+ name: Adam
+ lr: 0.0002 # for 4gpus bs128/gpu
+ weight_decay: 0.0
+
+LRScheduler:
+ name: MultiStepLR
+ milestones: [12]
+
+Architecture:
+ model_type: rec
+ algorithm: VisionLAN
+ Transform:
+ Encoder:
+ name: ResNet45
+ in_channels: 3
+ strides: [2, 2, 2, 1, 1]
+ Decoder:
+ name: VisionLANDecoder
+ training_step: &training_step 'LA'
+ n_position: 256
+
+Loss:
+ name: VisionLANLoss
+ training_step: *training_step
+
+PostProcess:
+ name: VisionLANLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - VisionLANLabelEncode:
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - VisionLANLabelEncode:
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 128
+ num_workers: 2
diff --git a/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml b/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml
new file mode 100644
index 0000000000000000000000000000000000000000..87511926135c5353896601612d554056207f0b4c
--- /dev/null
+++ b/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml
@@ -0,0 +1,102 @@
+Global:
+ device: gpu
+ epoch_num: 10
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet45_trans_visionlan_LF1/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_resnet45_trans_visionlan_LF1.txt
+ grad_clip_val: 20
+ use_amp: True
+
+Optimizer:
+ name: Adam
+ lr: 0.0002 # for 4gpus bs128/gpu
+ weight_decay: 0.0
+
+LRScheduler:
+ name: MultiStepLR
+ milestones: [6]
+
+Architecture:
+ model_type: rec
+ algorithm: VisionLAN
+ Transform:
+ Encoder:
+ name: ResNet45
+ in_channels: 3
+ strides: [2, 2, 2, 1, 1]
+ Decoder:
+ name: VisionLANDecoder
+ training_step: &training_step 'LF_1'
+ n_position: 256
+
+Loss:
+ name: VisionLANLoss
+ training_step: *training_step
+
+PostProcess:
+ name: VisionLANLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - VisionLANLabelEncode:
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - VisionLANLabelEncode:
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 128
+ num_workers: 2
diff --git a/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml b/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b595114f2256ecb797234f55ccc0ad672d1dcbee
--- /dev/null
+++ b/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml
@@ -0,0 +1,103 @@
+Global:
+ device: gpu
+ epoch_num: 10
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/resnet45_trans_visionlan_LF2/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./output/rec/u14m_filter/resnet45_trans_visionlan_LF1/best.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_resnet45_trans_visionlan_LF2.txt
+ grad_clip_val: 20
+ use_amp: True
+
+Optimizer:
+ name: Adam
+ lr: 0.0002 # for 4gpus bs128/gpu
+ weight_decay: 0.0
+
+LRScheduler:
+ name: MultiStepLR
+ milestones: [6]
+
+Architecture:
+ model_type: rec
+ algorithm: VisionLAN
+ Transform:
+ Encoder:
+ name: ResNet45
+ in_channels: 3
+ strides: [2, 2, 2, 1, 1]
+ Decoder:
+ name: VisionLANDecoder
+ training_step: &training_step 'LF_2'
+ n_position: 256
+
+Loss:
+ name: VisionLANLoss
+ training_step: *training_step
+
+PostProcess:
+ name: VisionLANLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - VisionLANLabelEncode:
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - VisionLANLabelEncode:
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 128
+ num_workers: 2
diff --git a/configs/rec/visionlan/svtrv2_visionlan_LA.yml b/configs/rec/visionlan/svtrv2_visionlan_LA.yml
new file mode 100644
index 0000000000000000000000000000000000000000..74e0f2b0f3e94f16186ed0443a64f3578f7bf7e6
--- /dev/null
+++ b/configs/rec/visionlan/svtrv2_visionlan_LA.yml
@@ -0,0 +1,112 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_visionlan_LA/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./output/rec/u14m_filter/svtrv2_visionlan_LF2/best.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_visionlan_LA.txt
+ grad_clip_val: 20
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: VisionLAN
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: VisionLANDecoder
+ training_step: &training_step 'LA'
+ n_position: 128
+
+Loss:
+ name: VisionLANLoss
+ training_step: *training_step
+
+PostProcess:
+ name: VisionLANLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - VisionLANLabelEncode:
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - VisionLANLabelEncode:
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml b/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml
new file mode 100644
index 0000000000000000000000000000000000000000..735f2f2fd5f312e42bac3907d090fc89fe2f3e03
--- /dev/null
+++ b/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml
@@ -0,0 +1,111 @@
+Global:
+ device: gpu
+ epoch_num: 10
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_visionlan_LF1/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_visionlan_LF1.txt
+ grad_clip_val: 20
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+  warmup_epoch: 1.5 # pct_start 0.15*10 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: VisionLAN
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: VisionLANDecoder
+ training_step: &training_step 'LF_1'
+ n_position: 128
+
+Loss:
+ name: VisionLANLoss
+ training_step: *training_step
+
+PostProcess:
+ name: VisionLANLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - VisionLANLabelEncode:
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - VisionLANLabelEncode:
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml b/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8974de9842bdb43ffe8d52c4278e6070b32517af
--- /dev/null
+++ b/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml
@@ -0,0 +1,112 @@
+Global:
+ device: gpu
+ epoch_num: 10
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_visionlan_LF2/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./output/rec/u14m_filter/svtrv2_visionlan_LF1/best.pth
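+  # The LF_2 step typically resumes from the best LF_1 checkpoint listed above.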
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_visionlan_LF2.txt
+ grad_clip_val: 20
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.00065 # for 4gpus bs256/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+  warmup_epoch: 1 # pct_start 0.1*10 = 1ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: VisionLAN
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: VisionLANDecoder
+ training_step: &training_step 'LF_2'
+ n_position: 128
+
+Loss:
+ name: VisionLANLoss
+ training_step: *training_step
+
+PostProcess:
+ name: VisionLANLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../Union14M-L-LMDB-Filtered
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - VisionLANLabelEncode:
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ../evaluation
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - VisionLANLabelEncode:
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - RecTVResize:
+ image_shape: [32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/opendet/modeling/__init__.py b/opendet/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..92cbe88b9469e2293a08568cc4325f5136f948ce
--- /dev/null
+++ b/opendet/modeling/__init__.py
@@ -0,0 +1,11 @@
+import copy
+
+from .base_detector import BaseDetector
+
+__all__ = ['build_model']
+
+
+def build_model(config):
+ config = copy.deepcopy(config)
+ det_model = BaseDetector(config)
+ return det_model
diff --git a/opendet/modeling/__pycache__/__init__.cpython-38.pyc b/opendet/modeling/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ec86cb248c93c1201f83423261bae711c288a63
Binary files /dev/null and b/opendet/modeling/__pycache__/__init__.cpython-38.pyc differ
diff --git a/opendet/modeling/__pycache__/base_detector.cpython-38.pyc b/opendet/modeling/__pycache__/base_detector.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d586b0d15a0b30f79e85447a4d039e25a578a82
Binary files /dev/null and b/opendet/modeling/__pycache__/base_detector.cpython-38.pyc differ
diff --git a/opendet/modeling/backbones/__init__.py b/opendet/modeling/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5387fb858f7ce70b377ad4b2b87ce01a8aa5efc7
--- /dev/null
+++ b/opendet/modeling/backbones/__init__.py
@@ -0,0 +1,14 @@
+__all__ = ['build_backbone']
+
+
+def build_backbone(config):
+ # det backbone
+ from .repvit import RepSVTR_det
+
+ support_dict = ['RepSVTR_det']
+
+ module_name = config.pop('name')
+ assert module_name in support_dict, Exception(
+        'det backbone only support {}'.format(support_dict))
+ module_class = eval(module_name)(**config)
+ return module_class
diff --git a/opendet/modeling/backbones/__pycache__/__init__.cpython-38.pyc b/opendet/modeling/backbones/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f16e0da6eca10e58a4a25e753e067548b786ef6e
Binary files /dev/null and b/opendet/modeling/backbones/__pycache__/__init__.cpython-38.pyc differ
diff --git a/opendet/modeling/backbones/__pycache__/repvit.cpython-38.pyc b/opendet/modeling/backbones/__pycache__/repvit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f5cba22aca3fe8b4743eaeb5475bed681139e7e
Binary files /dev/null and b/opendet/modeling/backbones/__pycache__/repvit.cpython-38.pyc differ
diff --git a/opendet/modeling/backbones/repvit.py b/opendet/modeling/backbones/repvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f73e91be97c593a970116f1a81f2f9fc201af47
--- /dev/null
+++ b/opendet/modeling/backbones/repvit.py
@@ -0,0 +1,340 @@
+"""
+This code is adapted from:
+https://github.com/THU-MIG/RepViT
+"""
+
+import torch.nn as nn
+import torch
+from torch.nn.init import constant_
+
+
+def _make_divisible(v, divisor, min_value=None):
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by 8
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ :param v:
+ :param divisor:
+ :param min_value:
+ :return:
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
+ min_value = min_value or divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < round_limit * v:
+ new_v += divisor
+ return new_v
+
+
+class SEModule(nn.Module):
+ """SE Module as defined in original SE-Nets with a few additions
+ Additions include:
+ * divisor can be specified to keep channels % div == 0 (default: 8)
+ * reduction channels can be specified directly by arg (if rd_channels is set)
+ * reduction channels can be specified by float rd_ratio (default: 1/16)
+ * global max pooling can be added to the squeeze aggregation
+ * customizable activation, normalization, and gate layer
+ """
+
+ def __init__(
+ self,
+ channels,
+ rd_ratio=1.0 / 16,
+ rd_channels=None,
+ rd_divisor=8,
+ act_layer=nn.ReLU,
+ ):
+ super(SEModule, self).__init__()
+ if not rd_channels:
+ rd_channels = make_divisible(channels * rd_ratio,
+ rd_divisor,
+ round_limit=0.0)
+ self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
+ self.act = act_layer()
+ self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
+
+ def forward(self, x):
+ x_se = x.mean((2, 3), keepdim=True)
+ x_se = self.fc1(x_se)
+ x_se = self.act(x_se)
+ x_se = self.fc2(x_se)
+ return x * torch.sigmoid(x_se)
+
+
+class Conv2D_BN(nn.Sequential):
+
+ def __init__(
+ self,
+ a,
+ b,
+ ks=1,
+ stride=1,
+ pad=0,
+ dilation=1,
+ groups=1,
+ bn_weight_init=1,
+ resolution=-10000,
+ ):
+ super().__init__()
+ self.add_module(
+ 'c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups,
+ bias=False))
+ self.add_module('bn', nn.BatchNorm2d(b))
+ constant_(self.bn.weight, bn_weight_init)
+ constant_(self.bn.bias, 0)
+
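+    # Inference-time fusion: fold the BatchNorm into the preceding conv using
+    # w' = w * gamma / sqrt(var + eps) and b' = beta - mean * gamma / sqrt(var + eps),
+    # so the Conv2d + BN pair can be replaced by a single Conv2d.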
+ @torch.no_grad()
+ def fuse(self):
+ c, bn = self._modules.values()
+ w = bn.weight / (bn.running_var + bn.eps)**0.5
+ w = c.weight * w[:, None, None, None]
+ b = bn.bias - bn.running_mean * bn.weight / \
+ (bn.running_var + bn.eps)**0.5
+ m = nn.Conv2d(w.size(1) * self.c.groups,
+ w.size(0),
+ w.shape[2:],
+ stride=self.c.stride,
+ padding=self.c.padding,
+ dilation=self.c.dilation,
+ groups=self.c.groups,
+ device=c.weight.device)
+ m.weight.data.copy_(w)
+ m.bias.data.copy_(b)
+ return m
+
+
+class Residual(torch.nn.Module):
+
+ def __init__(self, m, drop=0.):
+ super().__init__()
+ self.m = m
+ self.drop = drop
+
+ def forward(self, x):
+ if self.training and self.drop > 0:
+ return x + self.m(x) * torch.rand(
+ x.size(0), 1, 1, 1, device=x.device).ge_(
+ self.drop).div(1 - self.drop).detach()
+ else:
+ return x + self.m(x)
+
+ @torch.no_grad()
+ def fuse(self):
+ if isinstance(self.m, Conv2D_BN):
+ m = self.m.fuse()
+ assert (m.groups == m.in_channels)
+ identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
+ identity = nn.functional.pad(identity, [1, 1, 1, 1])
+ m.weight += identity.to(m.weight.device)
+ return m
+ elif isinstance(self.m, nn.Conv2d):
+ m = self.m
+ assert (m.groups != m.in_channels)
+ identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
+ identity = nn.functional.pad(identity, [1, 1, 1, 1])
+ m.weight += identity.to(m.weight.device)
+ return m
+ else:
+ return self
+
+
+class RepVGGDW(nn.Module):
+
+ def __init__(self, ed) -> None:
+ super().__init__()
+ self.conv = Conv2D_BN(ed, ed, 3, 1, 1, groups=ed)
+ self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
+ self.dim = ed
+ self.bn = nn.BatchNorm2d(ed)
+
+ def forward(self, x):
+ return self.bn((self.conv(x) + self.conv1(x)) + x)
+
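+    # Structural re-parameterization: the 1x1 depthwise branch (zero-padded to
+    # 3x3) and the identity shortcut are folded into the 3x3 depthwise conv,
+    # and the trailing BatchNorm is absorbed, leaving one 3x3 conv at inference.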
+ @torch.no_grad()
+ def fuse(self):
+ conv = self.conv.fuse()
+ conv1 = self.conv1
+
+ conv_w = conv.weight
+ conv_b = conv.bias
+ conv1_w = conv1.weight
+ conv1_b = conv1.bias
+
+ conv1_w = nn.functional.pad(conv1_w, [1, 1, 1, 1])
+
+ identity = nn.functional.pad(
+ torch.ones(conv1_w.shape[0],
+ conv1_w.shape[1],
+ 1,
+ 1,
+ device=conv1_w.device), [1, 1, 1, 1])
+
+ final_conv_w = conv_w + conv1_w + identity
+ final_conv_b = conv_b + conv1_b
+
+ conv.weight.data.copy_(final_conv_w)
+ conv.bias.data.copy_(final_conv_b)
+
+ bn = self.bn
+ w = bn.weight / (bn.running_var + bn.eps)**0.5
+ w = conv.weight * w[:, None, None, None]
+ b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
+ (bn.running_var + bn.eps)**0.5
+ conv.weight.data.copy_(w)
+ conv.bias.data.copy_(b)
+ return conv
+
+
+class RepViTBlock(nn.Module):
+
+ def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se,
+ use_hs):
+ super(RepViTBlock, self).__init__()
+
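+        # Each block is a token mixer (depthwise conv, optionally with SE)
+        # followed by a channel mixer (pointwise expand/reduce with a residual).
+        # The stride!=1 variant downsamples in the token mixer; the stride==1
+        # variant uses the re-parameterizable RepVGGDW mixer.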
+ self.identity = stride == 1 and inp == oup
+ assert hidden_dim == 2 * inp
+
+ if stride != 1:
+ self.token_mixer = nn.Sequential(
+ Conv2D_BN(inp,
+ inp,
+ kernel_size,
+ stride, (kernel_size - 1) // 2,
+ groups=inp),
+ SEModule(inp, 0.25) if use_se else nn.Identity(),
+ Conv2D_BN(inp, oup, ks=1, stride=1, pad=0),
+ )
+ self.channel_mixer = Residual(
+ nn.Sequential(
+ # pw
+ Conv2D_BN(oup, 2 * oup, 1, 1, 0),
+                nn.GELU() if use_hs else nn.GELU(),  # NOTE: both branches use GELU, so use_hs has no effect here
+ # pw-linear
+ Conv2D_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
+ ))
+ else:
+ assert self.identity
+ self.token_mixer = nn.Sequential(
+ RepVGGDW(inp),
+ SEModule(inp, 0.25) if use_se else nn.Identity(),
+ )
+ self.channel_mixer = Residual(
+ nn.Sequential(
+ # pw
+ Conv2D_BN(inp, hidden_dim, 1, 1, 0),
+                nn.GELU() if use_hs else nn.GELU(),  # NOTE: both branches use GELU, so use_hs has no effect here
+ # pw-linear
+ Conv2D_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
+ ))
+
+ def forward(self, x):
+ return self.channel_mixer(self.token_mixer(x))
+
+
+class RepViT(nn.Module):
+
+ def __init__(self, cfgs, in_channels=3, out_indices=None):
+ super(RepViT, self).__init__()
+ # setting of inverted residual blocks
+ self.cfgs = cfgs
+
+ # building first layer
+ input_channel = self.cfgs[0][2]
+ patch_embed = nn.Sequential(
+ Conv2D_BN(in_channels, input_channel // 2, 3, 2, 1),
+ nn.GELU(),
+ Conv2D_BN(input_channel // 2, input_channel, 3, 2, 1),
+ )
+ layers = [patch_embed]
+ # building inverted residual blocks
+ block = RepViTBlock
+ for k, t, c, use_se, use_hs, s in self.cfgs:
+ output_channel = _make_divisible(c, 8)
+ exp_size = _make_divisible(input_channel * t, 8)
+ layers.append(
+ block(input_channel, exp_size, output_channel, k, s, use_se,
+ use_hs))
+ input_channel = output_channel
+ self.features = nn.ModuleList(layers)
+ self.out_indices = out_indices
+ if out_indices is not None:
+ self.out_channels = [self.cfgs[ids - 1][2] for ids in out_indices]
+ else:
+ self.out_channels = self.cfgs[-1][2]
+
+ def forward(self, x):
+ if self.out_indices is not None:
+ return self.forward_det(x)
+ return self.forward_rec(x)
+
+ def forward_det(self, x):
+ outs = []
+ for i, f in enumerate(self.features):
+ x = f(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return outs
+
+ def forward_rec(self, x):
+ for f in self.features:
+ x = f(x)
+ h = x.shape[2]
+ x = nn.functional.avg_pool2d(x, [h, 2])
+ return x
+
+
+def RepSVTR(in_channels=3):
+ """
+    Constructs the RepSVTR recognition backbone (a RepViT variant)
+ """
+ # k, t, c, SE, HS, s
+ cfgs = [
+ [3, 2, 96, 1, 0, 1],
+ [3, 2, 96, 0, 0, 1],
+ [3, 2, 96, 0, 0, 1],
+ [3, 2, 192, 0, 1, (2, 1)],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 384, 0, 1, (2, 1)],
+ [3, 2, 384, 1, 1, 1],
+ [3, 2, 384, 0, 1, 1],
+ ]
+ return RepViT(cfgs, in_channels=in_channels)
+
+
+def RepSVTR_det(in_channels=3, out_indices=[2, 5, 10, 13]):
+ """
+    Constructs the RepSVTR_det detection backbone (a RepViT variant)
+ """
+ # k, t, c, SE, HS, s
+ cfgs = [
+ [3, 2, 48, 1, 0, 1],
+ [3, 2, 48, 0, 0, 1],
+ [3, 2, 96, 0, 0, 2],
+ [3, 2, 96, 1, 0, 1],
+ [3, 2, 96, 0, 0, 1],
+ [3, 2, 192, 0, 1, 2],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 384, 0, 1, 2],
+ [3, 2, 384, 1, 1, 1],
+ [3, 2, 384, 0, 1, 1],
+ ]
+ return RepViT(cfgs, in_channels=in_channels, out_indices=out_indices)
diff --git a/opendet/modeling/base_detector.py b/opendet/modeling/base_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ae14550973dc66cb398b2814724e49b4d33b9d0
--- /dev/null
+++ b/opendet/modeling/base_detector.py
@@ -0,0 +1,69 @@
+import torch
+from torch import nn
+
+from opendet.modeling.backbones import build_backbone
+from opendet.modeling.necks import build_neck
+from opendet.modeling.heads import build_head
+
+__all__ = ['BaseDetector']
+
+
+class BaseDetector(nn.Module):
+
+ def __init__(self, config):
+ """the module for OCR.
+
+ args:
+ config (dict): the super parameters for module.
+ """
+ super(BaseDetector, self).__init__()
+ in_channels = config.get('in_channels', 3)
+ self.use_wd = config.get('use_wd', True)
+
+ # build backbone
+ if 'Backbone' not in config or config['Backbone'] is None:
+ self.use_backbone = False
+ else:
+ self.use_backbone = True
+ config['Backbone']['in_channels'] = in_channels
+ self.backbone = build_backbone(config['Backbone'])
+ in_channels = self.backbone.out_channels
+
+ # build neck
+ if 'Neck' not in config or config['Neck'] is None:
+ self.use_neck = False
+ else:
+ self.use_neck = True
+ config['Neck']['in_channels'] = in_channels
+ self.neck = build_neck(config['Neck'])
+ in_channels = self.neck.out_channels
+
+ # build head
+ if 'Head' not in config or config['Head'] is None:
+ self.use_head = False
+ else:
+ self.use_head = True
+ config['Head']['in_channels'] = in_channels
+ self.head = build_head(config['Head'])
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ if self.use_wd:
+ if hasattr(self.backbone, 'no_weight_decay'):
+ no_weight_decay = self.backbone.no_weight_decay()
+ else:
+ no_weight_decay = {}
+ if hasattr(self.head, 'no_weight_decay'):
+ no_weight_decay.update(self.head.no_weight_decay())
+ return no_weight_decay
+ else:
+ return {}
+
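+    # Standard pipeline: backbone -> neck -> head; each stage is optional and
+    # skipped when its config section is missing or None.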
+ def forward(self, x, data=None):
+ if self.use_backbone:
+ x = self.backbone(x)
+ if self.use_neck:
+ x = self.neck(x)
+ if self.use_head:
+ x = self.head(x, data=data)
+ return x
diff --git a/opendet/modeling/heads/__init__.py b/opendet/modeling/heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..784ba182266af3dc9156c3674de5abd673a69995
--- /dev/null
+++ b/opendet/modeling/heads/__init__.py
@@ -0,0 +1,14 @@
+__all__ = ['build_head']
+
+
+def build_head(config):
+    # det head
+ from .db_head import DBHead
+
+ support_dict = ['DBHead']
+
+ module_name = config.pop('name')
+ assert module_name in support_dict, Exception(
+ 'det head only support {}'.format(support_dict))
+ module_class = eval(module_name)(**config)
+ return module_class
diff --git a/opendet/modeling/heads/__pycache__/__init__.cpython-38.pyc b/opendet/modeling/heads/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bff6890388af08a6502e161da6039963fb94866e
Binary files /dev/null and b/opendet/modeling/heads/__pycache__/__init__.cpython-38.pyc differ
diff --git a/opendet/modeling/heads/__pycache__/db_head.cpython-38.pyc b/opendet/modeling/heads/__pycache__/db_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..265caf5d0e5b7258e2a5634ab71dcf990254b227
Binary files /dev/null and b/opendet/modeling/heads/__pycache__/db_head.cpython-38.pyc differ
diff --git a/opendet/modeling/heads/db_head.py b/opendet/modeling/heads/db_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c56accb2fbb135762a9fd3e651923bb3709293b
--- /dev/null
+++ b/opendet/modeling/heads/db_head.py
@@ -0,0 +1,73 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class Head(nn.Module):
+
+ def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs):
+ super(Head, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=in_channels // 4,
+ kernel_size=kernel_list[0],
+ padding=int(kernel_list[0] // 2),
+ bias=False,
+ )
+ self.conv_bn1 = nn.BatchNorm2d(num_features=in_channels // 4, )
+
+ self.conv2 = nn.ConvTranspose2d(
+ in_channels=in_channels // 4,
+ out_channels=in_channels // 4,
+ kernel_size=kernel_list[1],
+ stride=2,
+ )
+ self.conv_bn2 = nn.BatchNorm2d(num_features=in_channels // 4, )
+ self.conv3 = nn.ConvTranspose2d(
+ in_channels=in_channels // 4,
+ out_channels=1,
+ kernel_size=kernel_list[2],
+ stride=2,
+ )
+
+ def forward(self, x, return_f=False):
+ x = self.conv1(x)
+ x = F.relu(self.conv_bn1(x))
+ x = self.conv2(x)
+ x = F.relu(self.conv_bn2(x))
+ if return_f is True:
+ f = x
+ x = self.conv3(x)
+ x = torch.sigmoid(x)
+ if return_f is True:
+ return x, f
+ return x
+
+
+class DBHead(nn.Module):
+ """
+ Differentiable Binarization (DB) for text detection:
+ see https://arxiv.org/abs/1911.08947
+ args:
+        params(dict): hyperparameters for building the DB head
+ """
+
+ def __init__(self, in_channels, k=50, **kwargs):
+ super(DBHead, self).__init__()
+ self.k = k
+ self.binarize = Head(in_channels, **kwargs)
+ self.thresh = Head(in_channels, **kwargs)
+
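+    # Differentiable binarization from the DB paper: approximate a hard
+    # threshold with a steep sigmoid, B = 1 / (1 + exp(-k * (P - T))), where P
+    # is the shrink map, T the threshold map, and k the amplification factor.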
+ def step_function(self, x, y):
+ return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
+
+ def forward(self, x, data=None):
+ shrink_maps = self.binarize(x)
+ if not self.training:
+ return {'maps': shrink_maps}
+
+ threshold_maps = self.thresh(x)
+ binary_maps = self.step_function(shrink_maps, threshold_maps)
+ y = torch.concat([shrink_maps, threshold_maps, binary_maps], dim=1)
+ return {'maps': y}
diff --git a/opendet/modeling/necks/__init__.py b/opendet/modeling/necks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2deccf84b2f8823f2716e977fab5f0c9a86349b6
--- /dev/null
+++ b/opendet/modeling/necks/__init__.py
@@ -0,0 +1,14 @@
+__all__ = ['build_neck']
+
+
+def build_neck(config):
+    # det neck
+ from .db_fpn import RSEFPN
+
+ support_dict = ['RSEFPN']
+
+ module_name = config.pop('name')
+ assert module_name in support_dict, Exception(
+ 'det neck only support {}'.format(support_dict))
+ module_class = eval(module_name)(**config)
+ return module_class
diff --git a/opendet/modeling/necks/__pycache__/__init__.cpython-38.pyc b/opendet/modeling/necks/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b535c75bae3ffc341836d0212591f3c5eb551ab9
Binary files /dev/null and b/opendet/modeling/necks/__pycache__/__init__.cpython-38.pyc differ
diff --git a/opendet/modeling/necks/__pycache__/db_fpn.cpython-38.pyc b/opendet/modeling/necks/__pycache__/db_fpn.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebda0a37375e4a7fde175f5a374cae5ff3e0b33f
Binary files /dev/null and b/opendet/modeling/necks/__pycache__/db_fpn.cpython-38.pyc differ
diff --git a/opendet/modeling/necks/db_fpn.py b/opendet/modeling/necks/db_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a858911651233fd516cd7a3147e4cec8fbd64024
--- /dev/null
+++ b/opendet/modeling/necks/db_fpn.py
@@ -0,0 +1,609 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class SEModule(nn.Module):
+
+ def __init__(self, in_channels, reduction=4):
+ super(SEModule, self).__init__()
+
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=in_channels // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ self.conv2 = nn.Conv2d(
+ in_channels=in_channels // reduction,
+ out_channels=in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ def forward(self, inputs):
+ outputs = self.avg_pool(inputs)
+ outputs = self.conv1(outputs)
+ outputs = F.relu(outputs)
+ outputs = self.conv2(outputs)
+ outputs = F.hardsigmoid(outputs)
+ return inputs * outputs
+
+
+class IntraCLBlock(nn.Module):
+
+ def __init__(self, in_channels=96, reduce_factor=4):
+ super(IntraCLBlock, self).__init__()
+ self.channels = in_channels
+ self.rf = reduce_factor
+ # weight_attr = paddle.nn.initializer.KaimingUniform()
+ self.conv1x1_reduce_channel = nn.Conv2d(self.channels,
+ self.channels // self.rf,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.conv1x1_return_channel = nn.Conv2d(self.channels // self.rf,
+ self.channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.v_layer_7x1 = nn.Conv2d(
+ self.channels // self.rf,
+ self.channels // self.rf,
+ kernel_size=(7, 1),
+ stride=(1, 1),
+ padding=(3, 0),
+ )
+ self.v_layer_5x1 = nn.Conv2d(
+ self.channels // self.rf,
+ self.channels // self.rf,
+ kernel_size=(5, 1),
+ stride=(1, 1),
+ padding=(2, 0),
+ )
+ self.v_layer_3x1 = nn.Conv2d(
+ self.channels // self.rf,
+ self.channels // self.rf,
+ kernel_size=(3, 1),
+ stride=(1, 1),
+ padding=(1, 0),
+ )
+
+ self.q_layer_1x7 = nn.Conv2d(
+ self.channels // self.rf,
+ self.channels // self.rf,
+ kernel_size=(1, 7),
+ stride=(1, 1),
+ padding=(0, 3),
+ )
+ self.q_layer_1x5 = nn.Conv2d(
+ self.channels // self.rf,
+ self.channels // self.rf,
+ kernel_size=(1, 5),
+ stride=(1, 1),
+ padding=(0, 2),
+ )
+ self.q_layer_1x3 = nn.Conv2d(
+ self.channels // self.rf,
+ self.channels // self.rf,
+ kernel_size=(1, 3),
+ stride=(1, 1),
+ padding=(0, 1),
+ )
+
+ # base
+ self.c_layer_7x7 = nn.Conv2d(
+ self.channels // self.rf,
+ self.channels // self.rf,
+ kernel_size=(7, 7),
+ stride=(1, 1),
+ padding=(3, 3),
+ )
+ self.c_layer_5x5 = nn.Conv2d(
+ self.channels // self.rf,
+ self.channels // self.rf,
+ kernel_size=(5, 5),
+ stride=(1, 1),
+ padding=(2, 2),
+ )
+ self.c_layer_3x3 = nn.Conv2d(
+ self.channels // self.rf,
+ self.channels // self.rf,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ )
+
+ self.bn = nn.BatchNorm2d(self.channels)
+ self.relu = nn.ReLU()
+
+ def forward(self, x):
+ x_new = self.conv1x1_reduce_channel(x)
+
+ x_7_c = self.c_layer_7x7(x_new)
+ x_7_v = self.v_layer_7x1(x_new)
+ x_7_q = self.q_layer_1x7(x_new)
+ x_7 = x_7_c + x_7_v + x_7_q
+
+ x_5_c = self.c_layer_5x5(x_7)
+ x_5_v = self.v_layer_5x1(x_7)
+ x_5_q = self.q_layer_1x5(x_7)
+ x_5 = x_5_c + x_5_v + x_5_q
+
+ x_3_c = self.c_layer_3x3(x_5)
+ x_3_v = self.v_layer_3x1(x_5)
+ x_3_q = self.q_layer_1x3(x_5)
+ x_3 = x_3_c + x_3_v + x_3_q
+
+ x_relation = self.conv1x1_return_channel(x_3)
+
+ x_relation = self.bn(x_relation)
+ x_relation = self.relu(x_relation)
+
+ return x + x_relation
+
+
+class DSConv(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ padding,
+ stride=1,
+ groups=None,
+ if_act=True,
+ act='relu',
+ **kwargs,
+ ):
+ super(DSConv, self).__init__()
+ if groups is None:
+ groups = in_channels
+ self.if_act = if_act
+ self.act = act
+ self.conv1 = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ bias=False,
+ )
+
+ self.bn1 = nn.BatchNorm2d(num_features=in_channels)
+
+ self.conv2 = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=int(in_channels * 4),
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ )
+
+ self.bn2 = nn.BatchNorm2d(num_features=int(in_channels * 4))
+
+ self.conv3 = nn.Conv2d(
+ in_channels=int(in_channels * 4),
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ )
+ self._c = [in_channels, out_channels]
+ if in_channels != out_channels:
+ self.conv_end = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ )
+
+ def forward(self, inputs):
+ x = self.conv1(inputs)
+ x = self.bn1(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+ if self.if_act:
+ if self.act == 'relu':
+ x = F.relu(x)
+ elif self.act == 'hardswish':
+ x = F.hardswish(x)
+ else:
+                raise ValueError(
+                    'Unsupported activation ({}) in DSConv; expected relu or hardswish.'.
+                    format(self.act))
+
+ x = self.conv3(x)
+ if self._c[0] != self._c[1]:
+ x = x + self.conv_end(inputs)
+ return x
+
+
+class DBFPN(nn.Module):
+
+ def __init__(self, in_channels, out_channels, use_asf=False, **kwargs):
+ super(DBFPN, self).__init__()
+ self.out_channels = out_channels
+ self.use_asf = use_asf
+ # weight_attr = paddle.nn.initializer.KaimingUniform()
+
+ self.in2_conv = nn.Conv2d(
+ in_channels=in_channels[0],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ bias=False,
+ )
+ self.in3_conv = nn.Conv2d(
+ in_channels=in_channels[1],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ bias=False,
+ )
+ self.in4_conv = nn.Conv2d(
+ in_channels=in_channels[2],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ bias=False,
+ )
+ self.in5_conv = nn.Conv2d(
+ in_channels=in_channels[3],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ bias=False,
+ )
+ self.p5_conv = nn.Conv2d(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ )
+ self.p4_conv = nn.Conv2d(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ )
+ self.p3_conv = nn.Conv2d(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ )
+ self.p2_conv = nn.Conv2d(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ )
+
+ if self.use_asf is True:
+ self.asf = ASFBlock(self.out_channels, self.out_channels // 4)
+
+ def forward(self, x):
+ c2, c3, c4, c5 = x
+
+ in5 = self.in5_conv(c5)
+ in4 = self.in4_conv(c4)
+ in3 = self.in3_conv(c3)
+ in2 = self.in2_conv(c2)
+
+ out4 = in4 + F.interpolate(
+ in5, scale_factor=2, mode='nearest', align_corners=None) # 1/16
+ out3 = in3 + F.interpolate(
+ out4, scale_factor=2, mode='nearest', align_corners=None) # 1/8
+ out2 = in2 + F.interpolate(
+ out3, scale_factor=2, mode='nearest', align_corners=None) # 1/4
+
+ p5 = self.p5_conv(in5)
+ p4 = self.p4_conv(out4)
+ p3 = self.p3_conv(out3)
+ p2 = self.p2_conv(out2)
+ p5 = F.interpolate(p5,
+ scale_factor=8,
+ mode='nearest',
+ align_corners=None)
+ p4 = F.interpolate(p4,
+ scale_factor=4,
+ mode='nearest',
+ align_corners=None)
+ p3 = F.interpolate(p3,
+ scale_factor=2,
+ mode='nearest',
+ align_corners=None)
+
+ fuse = torch.concat([p5, p4, p3, p2], dim=1)
+
+ if self.use_asf is True:
+ fuse = self.asf(fuse, [p5, p4, p3, p2])
+
+ return fuse
+
+
+class RSELayer(nn.Module):
+
+ def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
+ super(RSELayer, self).__init__()
+ # weight_attr = paddle.nn.initializer.KaimingUniform()
+ self.out_channels = out_channels
+ self.in_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ kernel_size=kernel_size,
+ padding=int(kernel_size // 2),
+ # weight_attr=ParamAttr(initializer=weight_attr),
+ bias=False,
+ )
+ self.se_block = SEModule(self.out_channels)
+ self.shortcut = shortcut
+
+ def forward(self, ins):
+ x = self.in_conv(ins)
+ if self.shortcut:
+ out = x + self.se_block(x)
+ else:
+ out = self.se_block(x)
+ return out
+
+
+class RSEFPN(nn.Module):
+
+ def __init__(self, in_channels, out_channels, shortcut=True, **kwargs):
+ super(RSEFPN, self).__init__()
+ self.out_channels = out_channels
+ self.ins_conv = nn.ModuleList()
+ self.inp_conv = nn.ModuleList()
+ self.intracl = False
+ if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
+ self.intracl = kwargs['intracl']
+ self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
+ self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
+ self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
+ self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
+
+ for i in range(len(in_channels)):
+ self.ins_conv.append(
+ RSELayer(in_channels[i],
+ out_channels,
+ kernel_size=1,
+ shortcut=shortcut))
+ self.inp_conv.append(
+ RSELayer(out_channels,
+ out_channels // 4,
+ kernel_size=3,
+ shortcut=shortcut))
+
+ def forward(self, x):
+ c2, c3, c4, c5 = x
+
+ in5 = self.ins_conv[3](c5)
+ in4 = self.ins_conv[2](c4)
+ in3 = self.ins_conv[1](c3)
+ in2 = self.ins_conv[0](c2)
+
+ out4 = in4 + F.interpolate(
+ in5, scale_factor=2, mode='nearest', align_corners=None) # 1/16
+ out3 = in3 + F.interpolate(
+ out4, scale_factor=2, mode='nearest', align_corners=None) # 1/8
+ out2 = in2 + F.interpolate(
+ out3, scale_factor=2, mode='nearest', align_corners=None) # 1/4
+
+ p5 = self.inp_conv[3](in5)
+ p4 = self.inp_conv[2](out4)
+ p3 = self.inp_conv[1](out3)
+ p2 = self.inp_conv[0](out2)
+
+ if self.intracl is True:
+ p5 = self.incl4(p5)
+ p4 = self.incl3(p4)
+ p3 = self.incl2(p3)
+ p2 = self.incl1(p2)
+
+ p5 = F.interpolate(p5,
+ scale_factor=8,
+ mode='nearest',
+ align_corners=None)
+ p4 = F.interpolate(p4,
+ scale_factor=4,
+ mode='nearest',
+ align_corners=None)
+ p3 = F.interpolate(p3,
+ scale_factor=2,
+ mode='nearest',
+ align_corners=None)
+
+ fuse = torch.concat([p5, p4, p3, p2], dim=1)
+ return fuse
+
+
+class LKPAN(nn.Module):
+
+ def __init__(self, in_channels, out_channels, mode='large', **kwargs):
+ super(LKPAN, self).__init__()
+ self.out_channels = out_channels
+ # weight_attr = paddle.nn.initializer.KaimingUniform()
+
+ self.ins_conv = nn.ModuleList()
+ self.inp_conv = nn.ModuleList()
+ # pan head
+ self.pan_head_conv = nn.ModuleList()
+ self.pan_lat_conv = nn.ModuleList()
+
+ if mode.lower() == 'lite':
+ p_layer = DSConv
+ elif mode.lower() == 'large':
+            p_layer = nn.Conv2d
+ else:
+ raise ValueError(
+ "mode can only be one of ['lite', 'large'], but received {}".
+ format(mode))
+
+ for i in range(len(in_channels)):
+ self.ins_conv.append(
+ nn.Conv2d(
+ in_channels=in_channels[i],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ bias=False,
+ ))
+
+ self.inp_conv.append(
+ p_layer(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=9,
+ padding=4,
+ bias=False,
+ ))
+
+ if i > 0:
+ self.pan_head_conv.append(
+ nn.Conv2d(
+ in_channels=self.out_channels // 4,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ stride=2,
+ bias=False,
+ ))
+ self.pan_lat_conv.append(
+ p_layer(
+ in_channels=self.out_channels // 4,
+ out_channels=self.out_channels // 4,
+ kernel_size=9,
+ padding=4,
+ bias=False,
+ ))
+
+ self.intracl = False
+ if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
+ self.intracl = kwargs['intracl']
+ self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
+ self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
+ self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
+ self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
+
+ def forward(self, x):
+ c2, c3, c4, c5 = x
+
+ in5 = self.ins_conv[3](c5)
+ in4 = self.ins_conv[2](c4)
+ in3 = self.ins_conv[1](c3)
+ in2 = self.ins_conv[0](c2)
+
+ out4 = in4 + F.interpolate(
+ in5, scale_factor=2, mode='nearest', align_corners=None) # 1/16
+ out3 = in3 + F.interpolate(
+ out4, scale_factor=2, mode='nearest', align_corners=None) # 1/8
+ out2 = in2 + F.interpolate(
+ out3, scale_factor=2, mode='nearest', align_corners=None) # 1/4
+
+ f5 = self.inp_conv[3](in5)
+ f4 = self.inp_conv[2](out4)
+ f3 = self.inp_conv[1](out3)
+ f2 = self.inp_conv[0](out2)
+
+ pan3 = f3 + self.pan_head_conv[0](f2)
+ pan4 = f4 + self.pan_head_conv[1](pan3)
+ pan5 = f5 + self.pan_head_conv[2](pan4)
+
+ p2 = self.pan_lat_conv[0](f2)
+ p3 = self.pan_lat_conv[1](pan3)
+ p4 = self.pan_lat_conv[2](pan4)
+ p5 = self.pan_lat_conv[3](pan5)
+
+ if self.intracl is True:
+ p5 = self.incl4(p5)
+ p4 = self.incl3(p4)
+ p3 = self.incl2(p3)
+ p2 = self.incl1(p2)
+
+ p5 = F.interpolate(p5,
+ scale_factor=8,
+ mode='nearest',
+ align_corners=None)
+ p4 = F.interpolate(p4,
+ scale_factor=4,
+ mode='nearest',
+ align_corners=None)
+ p3 = F.interpolate(p3,
+ scale_factor=2,
+ mode='nearest',
+ align_corners=None)
+
+ fuse = torch.concat([p5, p4, p3, p2], dim=1)
+ return fuse
+
+
+class ASFBlock(nn.Module):
+ """
+    This code is adapted from:
+ https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py
+ """
+
+ def __init__(self, in_channels, inter_channels, out_features_num=4):
+ """
+ Adaptive Scale Fusion (ASF) block of DBNet++
+ Args:
+ in_channels: the number of channels in the input data
+ inter_channels: the number of middle channels
+ out_features_num: the number of fused stages
+ """
+ super(ASFBlock, self).__init__()
+ # weight_attr = paddle.nn.initializer.KaimingUniform()
+ self.in_channels = in_channels
+ self.inter_channels = inter_channels
+ self.out_features_num = out_features_num
+ self.conv = nn.Conv2d(in_channels, inter_channels, 3, padding=1)
+
+ self.spatial_scale = nn.Sequential(
+ # Nx1xHxW
+ nn.Conv2d(
+ in_channels=1,
+ out_channels=1,
+ kernel_size=3,
+ bias=False,
+ padding=1,
+ ),
+ nn.ReLU(),
+ nn.Conv2d(
+ in_channels=1,
+ out_channels=1,
+ kernel_size=1,
+ bias=False,
+ ),
+ nn.Sigmoid(),
+ )
+
+ self.channel_scale = nn.Sequential(
+ nn.Conv2d(
+ in_channels=inter_channels,
+ out_channels=out_features_num,
+ kernel_size=1,
+ bias=False,
+ ),
+ nn.Sigmoid(),
+ )
+
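+    # Spatial attention computed from the channel-mean of the fused features is
+    # added back to them, then a per-stage channel attention re-weights each of
+    # the out_features_num pyramid levels before concatenation.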
+ def forward(self, fuse_features, features_list):
+ fuse_features = self.conv(fuse_features)
+ spatial_x = torch.mean(fuse_features, dim=1, keepdim=True)
+ attention_scores = self.spatial_scale(spatial_x) + fuse_features
+ attention_scores = self.channel_scale(attention_scores)
+ assert len(features_list) == self.out_features_num
+
+ out_list = []
+ for i in range(self.out_features_num):
+ out_list.append(attention_scores[:, i:i + 1] * features_list[i])
+ return torch.concat(out_list, dim=1)
diff --git a/opendet/postprocess/__init__.py b/opendet/postprocess/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e8cbc09f9d04f5ff0921576be24c838e827a76d
--- /dev/null
+++ b/opendet/postprocess/__init__.py
@@ -0,0 +1,18 @@
+import copy
+
+__all__ = ['build_post_process']
+
+from .db_postprocess import DBPostProcess
+
+support_dict = ['DBPostProcess']
+
+
+def build_post_process(config, global_config=None):
+ config = copy.deepcopy(config)
+ module_name = config.pop('name')
+ if global_config is not None:
+ config.update(global_config)
+ assert module_name in support_dict, Exception(
+ 'det post process only support {}'.format(support_dict))
+ module_class = eval(module_name)(**config)
+ return module_class
diff --git a/opendet/postprocess/__pycache__/__init__.cpython-38.pyc b/opendet/postprocess/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b11f19634c5c2f0c8ef499b629d608b85bd3ef6c
Binary files /dev/null and b/opendet/postprocess/__pycache__/__init__.cpython-38.pyc differ
diff --git a/opendet/postprocess/__pycache__/db_postprocess.cpython-38.pyc b/opendet/postprocess/__pycache__/db_postprocess.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbfdc4fad01406ee59c87924d050dada0f12fd9a
Binary files /dev/null and b/opendet/postprocess/__pycache__/db_postprocess.cpython-38.pyc differ
diff --git a/opendet/postprocess/db_postprocess.py b/opendet/postprocess/db_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd6b199f959e6d56e3cd050fdbb804b21536720c
--- /dev/null
+++ b/opendet/postprocess/db_postprocess.py
@@ -0,0 +1,273 @@
+import numpy as np
+import cv2
+import torch
+from shapely.geometry import Polygon
+import pyclipper
+"""
+This code is adapted from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
+"""
+
+
+class DBPostProcess(object):
+ """
+ The post process for Differentiable Binarization (DB).
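+
+    Pipeline: threshold the probability map, extract contours from the binary
+    mask, score each candidate region, expand it via unclip, and map the
+    resulting boxes back to the original image size.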
+ """
+
+ def __init__(
+ self,
+ thresh=0.3,
+ box_thresh=0.7,
+ max_candidates=1000,
+ unclip_ratio=2.0,
+ use_dilation=False,
+ score_mode='fast',
+ box_type='quad',
+ **kwargs,
+ ):
+ self.thresh = thresh
+ self.box_thresh = box_thresh
+ self.max_candidates = max_candidates
+ self.unclip_ratio = unclip_ratio
+ self.min_size = 3
+ self.score_mode = score_mode
+ self.box_type = box_type
+ assert score_mode in [
+ 'slow',
+ 'fast',
+ ], 'Score mode must be in [slow, fast] but got: {}'.format(score_mode)
+
+ self.dilation_kernel = None if not use_dilation else np.array([[1, 1],
+ [1, 1]])
+
+ def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+ """
+ _bitmap: single map with shape (1, H, W),
+ whose values are binarized as {0, 1}
+ """
+
+ bitmap = _bitmap
+ height, width = bitmap.shape
+
+ boxes = []
+ scores = []
+
+ contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
+ cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+
+ for contour in contours[:self.max_candidates]:
+ epsilon = 0.002 * cv2.arcLength(contour, True)
+ approx = cv2.approxPolyDP(contour, epsilon, True)
+ points = approx.reshape((-1, 2))
+ if points.shape[0] < 4:
+ continue
+
+ score = self.box_score_fast(pred, points.reshape(-1, 2))
+ if self.box_thresh > score:
+ continue
+
+ if points.shape[0] > 2:
+ box = self.unclip(points, self.unclip_ratio)
+ if len(box) > 1:
+ continue
+ else:
+ continue
+ box = np.array(box).reshape(-1, 2)
+ if len(box) == 0:
+ continue
+
+ _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
+ if sside < self.min_size + 2:
+ continue
+
+ box = np.array(box)
+ box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0,
+ dest_width)
+ box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0,
+ dest_height)
+ boxes.append(box.tolist())
+ scores.append(score)
+ return boxes, scores
+
+ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+ """
+ _bitmap: single map with shape (1, H, W),
+ whose values are binarized as {0, 1}
+ """
+
+ bitmap = _bitmap
+ height, width = bitmap.shape
+
+ outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
+ cv2.CHAIN_APPROX_SIMPLE)
+ if len(outs) == 3:
+ img, contours, _ = outs[0], outs[1], outs[2]
+ elif len(outs) == 2:
+ contours, _ = outs[0], outs[1]
+
+ num_contours = min(len(contours), self.max_candidates)
+
+ boxes = []
+ scores = []
+ for index in range(num_contours):
+ contour = contours[index]
+ points, sside = self.get_mini_boxes(contour)
+ if sside < self.min_size:
+ continue
+ points = np.array(points)
+ if self.score_mode == 'fast':
+ score = self.box_score_fast(pred, points.reshape(-1, 2))
+ else:
+ score = self.box_score_slow(pred, contour)
+ if self.box_thresh > score:
+ continue
+
+ box = self.unclip(points, self.unclip_ratio)
+ if len(box) > 1:
+ continue
+ box = np.array(box).reshape(-1, 1, 2)
+ box, sside = self.get_mini_boxes(box)
+ if sside < self.min_size + 2:
+ continue
+ box = np.array(box)
+
+ box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0,
+ dest_width)
+ box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0,
+ dest_height)
+ boxes.append(box.astype('int32'))
+ scores.append(score)
+ return np.array(boxes, dtype='int32'), scores
+
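+    # Expand the shrunk text region as in the DB paper: the offset distance is
+    # D = A * r / L, with A the polygon area, L its perimeter and r the unclip
+    # ratio; pyclipper performs the polygon offsetting.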
+ def unclip(self, box, unclip_ratio):
+ poly = Polygon(box)
+ distance = poly.area * unclip_ratio / poly.length
+ offset = pyclipper.PyclipperOffset()
+ offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ expanded = offset.Execute(distance)
+ return expanded
+
+ def get_mini_boxes(self, contour):
+ bounding_box = cv2.minAreaRect(contour)
+ points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+ index_1, index_2, index_3, index_4 = 0, 1, 2, 3
+ if points[1][1] > points[0][1]:
+ index_1 = 0
+ index_4 = 1
+ else:
+ index_1 = 1
+ index_4 = 0
+ if points[3][1] > points[2][1]:
+ index_2 = 2
+ index_3 = 3
+ else:
+ index_2 = 3
+ index_3 = 2
+
+ box = [
+ points[index_1], points[index_2], points[index_3], points[index_4]
+ ]
+ return box, min(bounding_box[1])
+
+ def box_score_fast(self, bitmap, _box):
+ """
+        box_score_fast: use the mean score inside the axis-aligned bounding box as the box score
+ """
+ h, w = bitmap.shape[:2]
+ box = _box.copy()
+ xmin = np.clip(np.floor(box[:, 0].min()).astype('int32'), 0, w - 1)
+ xmax = np.clip(np.ceil(box[:, 0].max()).astype('int32'), 0, w - 1)
+ ymin = np.clip(np.floor(box[:, 1].min()).astype('int32'), 0, h - 1)
+ ymax = np.clip(np.ceil(box[:, 1].max()).astype('int32'), 0, h - 1)
+
+ mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+ box[:, 0] = box[:, 0] - xmin
+ box[:, 1] = box[:, 1] - ymin
+ cv2.fillPoly(mask, box.reshape(1, -1, 2).astype('int32'), 1)
+ return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+
+ def box_score_slow(self, bitmap, contour):
+ """
+        box_score_slow: use the mean score inside the polygon as the box score
+ """
+ h, w = bitmap.shape[:2]
+ contour = contour.copy()
+ contour = np.reshape(contour, (-1, 2))
+
+ xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
+ xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
+ ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
+ ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
+
+ mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+
+ contour[:, 0] = contour[:, 0] - xmin
+ contour[:, 1] = contour[:, 1] - ymin
+
+ cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype('int32'), 1)
+ return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+
+ def __call__(self, outs_dict, shape_list):
+ pred = outs_dict['maps']
+ if isinstance(pred, torch.Tensor):
+ pred = pred.detach().cpu().numpy()
+ pred = pred[:, 0, :, :]
+ segmentation = pred > self.thresh
+
+ boxes_batch = []
+ for batch_index in range(pred.shape[0]):
+ src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
+ if self.dilation_kernel is not None:
+ mask = cv2.dilate(
+ np.array(segmentation[batch_index]).astype(np.uint8),
+ self.dilation_kernel,
+ )
+ else:
+ mask = segmentation[batch_index]
+ if self.box_type == 'poly':
+ boxes, scores = self.polygons_from_bitmap(
+ pred[batch_index], mask, src_w, src_h)
+ elif self.box_type == 'quad':
+ boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
+ src_w, src_h)
+ else:
+ raise ValueError(
+ "box_type can only be one of ['quad', 'poly']")
+
+ boxes_batch.append({'points': boxes})
+ return boxes_batch
+
+
+class DistillationDBPostProcess(object):
+
+ def __init__(
+ self,
+ model_name=['student'],
+ key=None,
+ thresh=0.3,
+ box_thresh=0.6,
+ max_candidates=1000,
+ unclip_ratio=1.5,
+ use_dilation=False,
+ score_mode='fast',
+ box_type='quad',
+ **kwargs,
+ ):
+ self.model_name = model_name
+ self.key = key
+ self.post_process = DBPostProcess(
+ thresh=thresh,
+ box_thresh=box_thresh,
+ max_candidates=max_candidates,
+ unclip_ratio=unclip_ratio,
+ use_dilation=use_dilation,
+ score_mode=score_mode,
+ box_type=box_type,
+ )
+
+ def __call__(self, predicts, shape_list):
+ results = {}
+ for k in self.model_name:
+ results[k] = self.post_process(predicts[k], shape_list=shape_list)
+ return results
diff --git a/opendet/preprocess/__init__.py b/opendet/preprocess/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..851cb650e0230784c7235fab89fe5b80867c5c57
--- /dev/null
+++ b/opendet/preprocess/__init__.py
@@ -0,0 +1,154 @@
+import io
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from .db_resize_for_test import DetResizeForTest
+
+
+class NormalizeImage(object):
+ """normalize image such as substract mean, divide std"""
+
+ def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
+ if isinstance(scale, str):
+ scale = eval(scale)
+ self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+ mean = mean if mean is not None else [0.485, 0.456, 0.406]
+ std = std if std is not None else [0.229, 0.224, 0.225]
+
+ shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+ self.mean = np.array(mean).reshape(shape).astype('float32')
+ self.std = np.array(std).reshape(shape).astype('float32')
+
+ def __call__(self, data):
+ img = data['image']
+ from PIL import Image
+
+ if isinstance(img, Image.Image):
+ img = np.array(img)
+ assert isinstance(img,
+ np.ndarray), "invalid input 'img' in NormalizeImage"
+ data['image'] = (img.astype('float32') * self.scale -
+ self.mean) / self.std
+ return data
+
+
+class ToCHWImage(object):
+ """convert hwc image to chw image"""
+
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, data):
+ img = data['image']
+ from PIL import Image
+
+ if isinstance(img, Image.Image):
+ img = np.array(img)
+ data['image'] = img.transpose((2, 0, 1))
+ return data
+
+
+class KeepKeys(object):
+
+ def __init__(self, keep_keys, **kwargs):
+ self.keep_keys = keep_keys
+
+ def __call__(self, data):
+ data_list = []
+ for key in self.keep_keys:
+ data_list.append(data[key])
+ return data_list
+
+
+def transform(data, ops=None):
+ """transform."""
+ if ops is None:
+ ops = []
+ for op in ops:
+ data = op(data)
+ if data is None:
+ return None
+ return data
+
+
+class DecodeImage(object):
+ """decode image."""
+
+ def __init__(self,
+ img_mode='RGB',
+ channel_first=False,
+ ignore_orientation=False,
+ **kwargs):
+ self.img_mode = img_mode
+ self.channel_first = channel_first
+ self.ignore_orientation = ignore_orientation
+
+ def __call__(self, data):
+ img = data['image']
+
+ assert type(img) is bytes and len(
+ img) > 0, "invalid input 'img' in DecodeImage"
+ img = np.frombuffer(img, dtype='uint8')
+ if self.ignore_orientation:
+ img = cv2.imdecode(
+ img, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)
+ else:
+ img = cv2.imdecode(img, 1)
+ if img is None:
+ return None
+ if self.img_mode == 'GRAY':
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ elif self.img_mode == 'RGB':
+ assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
+ img.shape)
+ img = img[:, :, ::-1]
+
+ if self.channel_first:
+ img = img.transpose((2, 0, 1))
+
+ data['image'] = img
+ return data
+
+
+class DecodeImagePIL(object):
+ """decode image."""
+
+ def __init__(self, img_mode='RGB', **kwargs):
+ self.img_mode = img_mode
+
+ def __call__(self, data):
+ img = data['image']
+        assert type(img) is bytes and len(
+            img) > 0, "invalid input 'img' in DecodeImagePIL"
+ buf = io.BytesIO(img)
+ img = Image.open(buf).convert('RGB')
+ if self.img_mode == 'Gray':
+ img = img.convert('L')
+ elif self.img_mode == 'BGR':
+            img = np.array(img)[:, :, ::-1]  # convert to a numpy array and reverse the channel order (RGB -> BGR)
+ img = Image.fromarray(np.uint8(img))
+ data['image'] = img
+ return data
+
+
+def create_operators(op_param_list, global_config=None):
+ """create operators based on the config.
+
+ Args:
+ params(list): a dict list, used to create some operators
+ """
+ assert isinstance(op_param_list, list), 'operator config should be a list'
+ ops = []
+ for operator in op_param_list:
+ assert isinstance(operator,
+ dict) and len(operator) == 1, 'yaml format error'
+ op_name = list(operator)[0]
+ param = {} if operator[op_name] is None else operator[op_name]
+ if global_config is not None:
+ param.update(global_config)
+ op = eval(op_name)(**param)
+ ops.append(op)
+ return ops
diff --git a/opendet/preprocess/__pycache__/__init__.cpython-38.pyc b/opendet/preprocess/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1be3a80b88058657ab054517a5315ef951462e6
Binary files /dev/null and b/opendet/preprocess/__pycache__/__init__.cpython-38.pyc differ
diff --git a/opendet/preprocess/__pycache__/db_resize_for_test.cpython-38.pyc b/opendet/preprocess/__pycache__/db_resize_for_test.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..799f9127a95e2da6606245b4e320a11b6646a7e7
Binary files /dev/null and b/opendet/preprocess/__pycache__/db_resize_for_test.cpython-38.pyc differ
diff --git a/opendet/preprocess/crop_resize.py b/opendet/preprocess/crop_resize.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa67a472aeb3331bbcc4c147412919427cdfbde5
--- /dev/null
+++ b/opendet/preprocess/crop_resize.py
@@ -0,0 +1,121 @@
+import cv2
+
+
+def padding_image(img, size=(640, 640)):
+ """
+    Pad an image to the target size using OpenCV (constant black border added
+    on the bottom and right).
+
+    :param img: Input image as a numpy array.
+    :param size: The target (width, height) to pad to (default 640x640).
+    :return: The padded image.
+ """
+
+ img_height, img_width = img.shape[:2]
+ target_width, target_height = size
+
+ # If image is smaller than target size, pad the image to 640x640
+
+ # Calculate padding amounts (top, bottom, left, right)
+ pad_top = 0
+ pad_bottom = target_height - img_height
+ pad_left = 0
+ pad_right = target_width - img_width
+
+    # Pad the image (black padding, border type: constant)
+ padded_img = cv2.copyMakeBorder(img,
+ pad_top,
+ pad_bottom,
+ pad_left,
+ pad_right,
+ cv2.BORDER_CONSTANT,
+ value=[0, 0, 0])
+
+ # Return the padded area positions (top-left and bottom-right coordinates of the original image)
+ return padded_img
+
+
+class CropResize(object):
+
+ def __init__(self, size=(640, 640), interpolation=cv2.INTER_LINEAR):
+ self.size = size
+ self.interpolation = interpolation
+
+ def __call__(self, data):
+ """
+ Resize an image using OpenCV:
+ - If the image is smaller than the target size, pad it to 640x640.
+ - If the image is larger than the target size, split it into multiple 640x640 images and record positions.
+
+        :param data: Dict whose 'image' entry holds the input image array.
+        :return: A tuple (list of 640x640 crops, list of [left, top, right, bottom]
+            positions of each crop in the original image).
+ """
+ img = data['image']
+ img_height, img_width = img.shape[:2]
+ target_width, target_height = self.size
+
+ # If image is smaller than target size, pad the image to 640x640
+ if img_width <= target_width and img_height <= target_height:
+ # Calculate padding amounts (top, bottom, left, right)
+ if img_width == target_width and img_height == target_height:
+ return [img], [[0, 0, img_width, img_height]]
+ padded_img = padding_image(img, self.size)
+
+ # Return the padded area positions (top-left and bottom-right coordinates of the original image)
+ return [padded_img], [[0, 0, img_width, img_height]]
+
+ if img_width < target_width:
+ img = cv2.copyMakeBorder(img,
+ 0,
+ 0,
+ 0,
+ target_width - img_width,
+ cv2.BORDER_CONSTANT,
+ value=[0, 0, 0])
+
+ if img_height < target_height:
+ img = cv2.copyMakeBorder(img,
+ 0,
+ target_height - img_height,
+ 0,
+ 0,
+ cv2.BORDER_CONSTANT,
+ value=[0, 0, 0])
+ # raise ValueError("Image dimensions must be greater than or equal to target size")
+
+ img_height, img_width = img.shape[:2]
+ # If image is larger than or equal to target size, crop it into 640x640 tiles
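+        # Tiles are taken with a stride of half the tile size, so neighbouring
+        # crops overlap by 50%; crops that run past the right/bottom edge are
+        # shifted back inside the image and padded to the full tile size.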
+ crop_positions = []
+ count = 0
+ cropped_img_list = []
+ for top in range(0, img_height - target_height // 2,
+ target_height // 2):
+            for left in range(0, img_width - target_width // 2,
+ target_width // 2):
+ # Calculate the bottom and right boundaries for the crop
+ right = min(left + target_width, img_width)
+ bottom = min(top + target_height, img_height)
+ if right > img_width:
+ right = img_width
+ left = max(0, right - target_width)
+ if bottom > img_height:
+ bottom = img_height
+ top = max(0, bottom - target_height)
+ # Crop the image
+ cropped_img = img[top:bottom, left:right]
+ if bottom - top < target_height or right - left < target_width:
+ cropped_img = padding_image(cropped_img, self.size)
+
+ count += 1
+ cropped_img_list.append(cropped_img)
+
+ # Record the position of the cropped image
+ crop_positions.append([left, top, right, bottom])
+
+ # print(f"Images cropped and saved at {output_dir}.")
+
+ return cropped_img_list, crop_positions
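+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative, not part of the pipeline above):
+    # tile a random 900x1300 image into overlapping 640x640 crops and print
+    # how many tiles were produced and the positions of the first two.
+    import numpy as np
+
+    dummy = (np.random.rand(900, 1300, 3) * 255).astype('uint8')
+    cropper = CropResize(size=(640, 640))
+    tiles, positions = cropper({'image': dummy})
+    print(len(tiles), positions[:2])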
diff --git a/opendet/preprocess/db_resize_for_test.py b/opendet/preprocess/db_resize_for_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbd95f085aee971946cbe66bf4ac84b7854bd6ee
--- /dev/null
+++ b/opendet/preprocess/db_resize_for_test.py
@@ -0,0 +1,135 @@
+import math
+import sys
+import cv2
+import numpy as np
+
+
+class DetResizeForTest(object):
+
+ def __init__(self, **kwargs):
+ super(DetResizeForTest, self).__init__()
+ self.resize_type = 0
+ self.keep_ratio = False
+ if 'image_shape' in kwargs:
+ self.image_shape = kwargs['image_shape']
+ self.resize_type = 1
+ if 'keep_ratio' in kwargs:
+ self.keep_ratio = kwargs['keep_ratio']
+ elif 'limit_side_len' in kwargs:
+ self.limit_side_len = kwargs['limit_side_len']
+ self.limit_type = kwargs.get('limit_type', 'min')
+ elif 'resize_long' in kwargs:
+ self.resize_type = 2
+ self.resize_long = kwargs.get('resize_long', 960)
+ else:
+ self.limit_side_len = 736
+ self.limit_type = 'min'
+
+ def __call__(self, data):
+ img = data['image']
+ src_h, src_w, _ = img.shape
+ if sum([src_h, src_w]) < 64:
+ img = self.image_padding(img)
+
+ if self.resize_type == 0:
+ # img, shape = self.resize_image_type0(img)
+ img, [ratio_h, ratio_w] = self.resize_image_type0(img)
+ elif self.resize_type == 2:
+ img, [ratio_h, ratio_w] = self.resize_image_type2(img)
+ else:
+ # img, shape = self.resize_image_type1(img)
+ img, [ratio_h, ratio_w] = self.resize_image_type1(img)
+ data['image'] = img
+ data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+ return data
+
+ def image_padding(self, im, value=0):
+ h, w, c = im.shape
+ im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
+ im_pad[:h, :w, :] = im
+ return im_pad
+
+ def resize_image_type1(self, img):
+ resize_h, resize_w = self.image_shape
+ ori_h, ori_w = img.shape[:2] # (h, w, c)
+ if self.keep_ratio is True:
+ resize_w = ori_w * resize_h / ori_h
+ N = math.ceil(resize_w / 32)
+ resize_w = N * 32
+ ratio_h = float(resize_h) / ori_h
+ ratio_w = float(resize_w) / ori_w
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
+ # return img, np.array([ori_h, ori_w])
+ return img, [ratio_h, ratio_w]
+
+ def resize_image_type0(self, img):
+        """Resize the image so that both sides are multiples of 32, as required by the network.
+
+        args:
+            img (array): array with shape [h, w, c]
+        return (tuple):
+            img, (ratio_h, ratio_w)
+        """
+ limit_side_len = self.limit_side_len
+ h, w, c = img.shape
+
+ # limit the max side
+ if self.limit_type == 'max':
+ if max(h, w) > limit_side_len:
+ if h > w:
+ ratio = float(limit_side_len) / h
+ else:
+ ratio = float(limit_side_len) / w
+ else:
+ ratio = 1.0
+ elif self.limit_type == 'min':
+ if min(h, w) < limit_side_len:
+ if h < w:
+ ratio = float(limit_side_len) / h
+ else:
+ ratio = float(limit_side_len) / w
+ else:
+ ratio = 1.0
+ elif self.limit_type == 'resize_long':
+ ratio = float(limit_side_len) / max(h, w)
+ else:
+            raise Exception('Unsupported limit_type: {}'.format(self.limit_type))
+ resize_h = int(h * ratio)
+ resize_w = int(w * ratio)
+
+ resize_h = max(int(round(resize_h / 32) * 32), 32)
+ resize_w = max(int(round(resize_w / 32) * 32), 32)
+
+ try:
+ if int(resize_w) <= 0 or int(resize_h) <= 0:
+ return None, (None, None)
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        except Exception:
+            print(img.shape, resize_w, resize_h)
+            sys.exit(1)
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+ return img, [ratio_h, ratio_w]
+
+ def resize_image_type2(self, img):
+ h, w, _ = img.shape
+
+ resize_w = w
+ resize_h = h
+
+ if resize_h > resize_w:
+ ratio = float(self.resize_long) / resize_h
+ else:
+ ratio = float(self.resize_long) / resize_w
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+
+ return img, [ratio_h, ratio_w]
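+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative): resize a random image with the
+    # 'max' side limited to 960 and print the new shape and resize ratios.
+    dummy = (np.random.rand(720, 1280, 3) * 255).astype(np.uint8)
+    resizer = DetResizeForTest(limit_side_len=960, limit_type='max')
+    out = resizer({'image': dummy})
+    print(out['image'].shape, out['shape'])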
diff --git a/openrec/losses/__init__.py b/openrec/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7af475d04f068aa8b7785496cfd047801090c4e2
--- /dev/null
+++ b/openrec/losses/__init__.py
@@ -0,0 +1,62 @@
+import copy
+
+from torch import nn
+
+from .abinet_loss import ABINetLoss
+from .ar_loss import ARLoss
+from .cdistnet_loss import CDistNetLoss
+from .ce_loss import CELoss
+from .cppd_loss import CPPDLoss
+from .ctc_loss import CTCLoss
+from .igtr_loss import IGTRLoss
+from .lister_loss import LISTERLoss
+from .lpv_loss import LPVLoss
+from .mgp_loss import MGPLoss
+from .parseq_loss import PARSeqLoss
+from .robustscanner_loss import RobustScannerLoss
+from .smtr_loss import SMTRLoss
+from .srn_loss import SRNLoss
+from .visionlan_loss import VisionLANLoss
+from .cam_loss import CAMLoss
+from .seed_loss import SEEDLoss
+
+support_dict = [
+ 'CTCLoss', 'ARLoss', 'CELoss', 'CPPDLoss', 'ABINetLoss', 'CDistNetLoss',
+ 'VisionLANLoss', 'PARSeqLoss', 'IGTRLoss', 'SMTRLoss', 'LPVLoss',
+ 'RobustScannerLoss', 'SRNLoss', 'LISTERLoss', 'GTCLoss', 'MGPLoss',
+ 'CAMLoss', 'SEEDLoss'
+]
+
+
+def build_loss(config):
+ config = copy.deepcopy(config)
+ module_name = config.pop('name')
+    assert module_name in support_dict, Exception(
+        'loss only supports {}'.format(support_dict))
+ module_class = eval(module_name)(**config)
+ return module_class
+
+
+class GTCLoss(nn.Module):
+
+ def __init__(self,
+ gtc_loss,
+ gtc_weight=1.0,
+ ctc_weight=1.0,
+ zero_infinity=True,
+ **kwargs):
+ super(GTCLoss, self).__init__()
+ self.ctc_loss = CTCLoss(zero_infinity=zero_infinity)
+ self.gtc_loss = build_loss(gtc_loss)
+ self.gtc_weight = gtc_weight
+ self.ctc_weight = ctc_weight
+
+ def forward(self, predicts, batch):
+ ctc_loss = self.ctc_loss(predicts['ctc_pred'],
+ [None] + batch[-2:])['loss']
+ gtc_loss = self.gtc_loss(predicts['gtc_pred'], batch[:-2])['loss']
+ return {
+ 'loss': self.ctc_weight * ctc_loss + self.gtc_weight * gtc_loss,
+ 'ctc_loss': ctc_loss,
+ 'gtc_loss': gtc_loss
+ }
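+
+
+# Usage sketch (illustrative): build_loss instantiates any supported loss from
+# its config dict, e.g. build_loss({'name': 'CTCLoss', 'zero_infinity': True}),
+# while GTCLoss combines a CTC branch with a second loss, e.g.
+# build_loss({'name': 'GTCLoss', 'gtc_loss': {'name': 'ARLoss'}}).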
diff --git a/openrec/losses/__pycache__/__init__.cpython-38.pyc b/openrec/losses/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d67679953a8a4d46f140eced5d97b0091c5c2658
Binary files /dev/null and b/openrec/losses/__pycache__/__init__.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/abinet_loss.cpython-38.pyc b/openrec/losses/__pycache__/abinet_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..94d719c78d8060df5fedfc431bac29344499b1c8
Binary files /dev/null and b/openrec/losses/__pycache__/abinet_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/ar_loss.cpython-38.pyc b/openrec/losses/__pycache__/ar_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc22f9f9c87ea77948f14e9458d5bd77b09e8705
Binary files /dev/null and b/openrec/losses/__pycache__/ar_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/cam_loss.cpython-38.pyc b/openrec/losses/__pycache__/cam_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b481fb70d931b55750efd1a3cfc21f3b7ce2f38
Binary files /dev/null and b/openrec/losses/__pycache__/cam_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/cdistnet_loss.cpython-38.pyc b/openrec/losses/__pycache__/cdistnet_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a13299fcc55fa13448922a42919d488c5ac5827b
Binary files /dev/null and b/openrec/losses/__pycache__/cdistnet_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/ce_loss.cpython-38.pyc b/openrec/losses/__pycache__/ce_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82457ae022e1c299bfbce8992325821d5a76dffd
Binary files /dev/null and b/openrec/losses/__pycache__/ce_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/cppd_loss.cpython-38.pyc b/openrec/losses/__pycache__/cppd_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2314c408d544d07285fe68e8e36aa2917b58801
Binary files /dev/null and b/openrec/losses/__pycache__/cppd_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/ctc_loss.cpython-38.pyc b/openrec/losses/__pycache__/ctc_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8815e664e179a0342fb3ef78c9e7267916bcddd3
Binary files /dev/null and b/openrec/losses/__pycache__/ctc_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/igtr_loss.cpython-38.pyc b/openrec/losses/__pycache__/igtr_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d778713f5614637bb2f8075fdd419e3c7773417
Binary files /dev/null and b/openrec/losses/__pycache__/igtr_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/lister_loss.cpython-38.pyc b/openrec/losses/__pycache__/lister_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..84e97027b3e0104301889146af6c0f869ac0456b
Binary files /dev/null and b/openrec/losses/__pycache__/lister_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/lpv_loss.cpython-38.pyc b/openrec/losses/__pycache__/lpv_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa5fc5e3ff92bece59e388a1404c69a7e8874e5b
Binary files /dev/null and b/openrec/losses/__pycache__/lpv_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/mgp_loss.cpython-38.pyc b/openrec/losses/__pycache__/mgp_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..173ce497ff43d2692368e83dff94a06142fc6e4d
Binary files /dev/null and b/openrec/losses/__pycache__/mgp_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/parseq_loss.cpython-38.pyc b/openrec/losses/__pycache__/parseq_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e4aab40aeb221915fa6356bb9779299b4b9bf79
Binary files /dev/null and b/openrec/losses/__pycache__/parseq_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/robustscanner_loss.cpython-38.pyc b/openrec/losses/__pycache__/robustscanner_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8f49117b0fb16fb87b1097c1c6233ece98fdbe6
Binary files /dev/null and b/openrec/losses/__pycache__/robustscanner_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/seed_loss.cpython-38.pyc b/openrec/losses/__pycache__/seed_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3c3527507799adbe4632bccb195053e7686230c
Binary files /dev/null and b/openrec/losses/__pycache__/seed_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/smtr_loss.cpython-38.pyc b/openrec/losses/__pycache__/smtr_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..abc890dd06e27b0afc5b6c1a9c4431b3d69f521d
Binary files /dev/null and b/openrec/losses/__pycache__/smtr_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/srn_loss.cpython-38.pyc b/openrec/losses/__pycache__/srn_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15ac03669a193d1edcd4f57dd6c3ca3f813896a9
Binary files /dev/null and b/openrec/losses/__pycache__/srn_loss.cpython-38.pyc differ
diff --git a/openrec/losses/__pycache__/visionlan_loss.cpython-38.pyc b/openrec/losses/__pycache__/visionlan_loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38a17e4629f493e147f477237cfeda20f7e833d9
Binary files /dev/null and b/openrec/losses/__pycache__/visionlan_loss.cpython-38.pyc differ
diff --git a/openrec/losses/abinet_loss.py b/openrec/losses/abinet_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bc8fe183eed97089de4850a9e5ecf26c0c15d6a
--- /dev/null
+++ b/openrec/losses/abinet_loss.py
@@ -0,0 +1,42 @@
+import torch
+from torch import nn
+
+
+class ABINetLoss(nn.Module):
+
+ def __init__(self,
+ smoothing=False,
+ ignore_index=100,
+ align_weight=1.0,
+ **kwargs):
+ super(ABINetLoss, self).__init__()
+ if ignore_index >= 0:
+ self.loss_func = nn.CrossEntropyLoss(reduction='mean',
+ ignore_index=ignore_index)
+ else:
+ self.loss_func = nn.CrossEntropyLoss(reduction='mean')
+ self.smoothing = smoothing
+ self.align_weight = align_weight
+
+ def forward(self, pred, batch):
+ loss = {}
+ loss_sum = []
+ for name, logits in pred.items():
+ if isinstance(logits, list):
+ logit_num = len(logits)
+ if logit_num > 0:
+ all_tgt = torch.cat([batch[1]] * logit_num, 0)
+ all_logits = torch.cat(logits, 0)
+                    flt_logits = all_logits.reshape([-1, all_logits.shape[2]])
+ flt_tgt = all_tgt.reshape([-1])
+ else:
+ continue
+ else:
+                flt_logits = logits.reshape([-1, logits.shape[2]])
+ flt_tgt = batch[1].reshape([-1])
+
+            loss[name + '_loss'] = self.loss_func(flt_logits, flt_tgt) * (
+                self.align_weight if name == 'align' else 1.0)
+ loss_sum.append(loss[name + '_loss'])
+ loss['loss'] = sum(loss_sum)
+ return loss
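+
+
+# Input layout assumed by the forward pass above (illustrative note): pred maps
+# head names to logits of shape [B, T, C] (or to lists of such tensors for
+# iterative heads), and batch[1] holds the padded target indices.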
diff --git a/openrec/losses/ar_loss.py b/openrec/losses/ar_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a24447e90b613a1e59fd5c77b1310601b4ecbc2
--- /dev/null
+++ b/openrec/losses/ar_loss.py
@@ -0,0 +1,23 @@
+import torch.nn.functional as F
+from torch import nn
+
+
+class ARLoss(nn.Module):
+
+ def __init__(self, label_smoothing=0.1, ignore_index=0, **kwargs):
+ super(ARLoss, self).__init__()
+ self.label_smoothing = label_smoothing
+
+ def forward(self, pred, batch):
+ max_len = batch[2].max()
+ tgt = batch[1][:, 1:2 + max_len]
+ pred = pred.flatten(0, 1)
+ tgt = tgt.reshape([-1])
+ loss = F.cross_entropy(
+ pred,
+ tgt,
+ reduction='mean',
+ label_smoothing=self.label_smoothing,
+ ignore_index=pred.shape[1] + 1,
+ ) # self.loss_func(pred, tgt)
+ return {'loss': loss}
diff --git a/openrec/losses/cam_loss.py b/openrec/losses/cam_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..09ee169a19fb3d361b3f0c82476d96f822fe1737
--- /dev/null
+++ b/openrec/losses/cam_loss.py
@@ -0,0 +1,48 @@
+import torch
+import torch.nn.functional as F
+
+from .ar_loss import ARLoss
+
+
+def BalancedMultiClassCrossEntropyLoss(x_o, x_t):
+ # [B, num_cls, H, W]
+ B, num_cls, H, W = x_o.shape
+ x_o = x_o.reshape(B, num_cls, H * W).permute(0, 2, 1)
+ # [B, H, W, num_cls]
+ # generate gt
+ x_t[x_t > 0.5] = 1
+ x_t[x_t <= 0.5] = 0
+ fg_x_t = x_t.sum(-1) # [B, H, W]
+ x_t = x_t.argmax(-1) # [B, H, W]
+ x_t[fg_x_t == 0] = num_cls - 1 # background
+ x_t = x_t.reshape(B, H * W)
+ # loss
+ weight = torch.ones((B, num_cls)).type_as(x_o) # the weight of bg is 1.
+ num_bg = (x_t == (num_cls - 1)).sum(-1) # [B]
+ weight[:, :-1] = (num_bg / (H * W - num_bg + 1e-5)).unsqueeze(-1).expand(
+ -1, num_cls - 1)
+ logit = F.log_softmax(x_o, dim=-1) # [B, H*W, num_cls]
+ logit = logit * weight.unsqueeze(1)
+ loss = -logit.gather(2, x_t.unsqueeze(-1).long())
+ return loss.mean()
+
+
+class CAMLoss(ARLoss):
+
+ def __init__(self, label_smoothing=0.1, loss_weight_binary=1.5, **kwargs):
+ super(CAMLoss, self).__init__(label_smoothing=label_smoothing)
+ self.label_smoothing = label_smoothing
+ self.loss_weight_binary = loss_weight_binary
+
+ def forward(self, pred, batch):
+ binary_mask = batch[-1]
+ rec_loss = super().forward(pred['rec_output'], batch[:-1])['loss']
+ output = pred
+        loss_binary = self.loss_weight_binary * BalancedMultiClassCrossEntropyLoss(
+            output['pred_binary'], binary_mask)
+
+ return {
+ 'loss': rec_loss + loss_binary,
+ 'rec_loss': rec_loss,
+ 'loss_binary': loss_binary
+ }
diff --git a/openrec/losses/cdistnet_loss.py b/openrec/losses/cdistnet_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d33b97356844e95c0ec859f299966c4434cf64d8
--- /dev/null
+++ b/openrec/losses/cdistnet_loss.py
@@ -0,0 +1,34 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class CDistNetLoss(nn.Module):
+
+ def __init__(self, smoothing=True, ignore_index=0, **kwargs):
+ super(CDistNetLoss, self).__init__()
+ if ignore_index >= 0 and not smoothing:
+ self.loss_func = nn.CrossEntropyLoss(reduction='mean',
+ ignore_index=ignore_index)
+ self.smoothing = smoothing
+
+ def forward(self, pred, batch):
+ pred = pred['res']
+ tgt = batch[1][:, 1:]
+ pred = pred.reshape([-1, pred.shape[2]])
+ tgt = tgt.reshape([-1])
+ if self.smoothing:
+ eps = 0.1
+ n_class = pred.shape[1]
+ one_hot = F.one_hot(tgt.long(), num_classes=pred.shape[1])
+ torch.set_printoptions(profile='full')
+ one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
+ log_prb = F.log_softmax(pred, dim=1)
+ non_pad_mask = torch.not_equal(
+ tgt, torch.zeros(tgt.shape, dtype=tgt.dtype,
+ device=tgt.device))
+ loss = -(one_hot * log_prb).sum(dim=1)
+ loss = loss.masked_select(non_pad_mask).mean()
+ else:
+ loss = self.loss_func(pred, tgt)
+ return {'loss': loss}
diff --git a/openrec/losses/ce_loss.py b/openrec/losses/ce_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..f56f23c17566694971c58da8a70b14296f4d08d7
--- /dev/null
+++ b/openrec/losses/ce_loss.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class CELoss(nn.Module):
+
+ def __init__(self,
+ smoothing=False,
+ with_all=False,
+ ignore_index=-1,
+ **kwargs):
+ super(CELoss, self).__init__()
+ if ignore_index >= 0:
+ self.loss_func = nn.CrossEntropyLoss(reduction='mean',
+ ignore_index=ignore_index)
+ else:
+ self.loss_func = nn.CrossEntropyLoss(reduction='mean')
+ self.smoothing = smoothing
+ self.with_all = with_all
+
+ def forward(self, pred, batch):
+ pred = pred['res']
+ if isinstance(pred, dict): # for ABINet
+ loss = {}
+ loss_sum = []
+ for name, logits in pred.items():
+ if isinstance(logits, list):
+ logit_num = len(logits)
+ all_tgt = torch.cat([batch[1]] * logit_num, 0)
+ all_logits = torch.cat(logits, 0)
+                    flt_logits = all_logits.reshape([-1, all_logits.shape[2]])
+ flt_tgt = all_tgt.reshape([-1])
+ else:
+                    flt_logits = logits.reshape([-1, logits.shape[2]])
+ flt_tgt = batch[1].reshape([-1])
+                loss[name + '_loss'] = self.loss_func(flt_logits, flt_tgt)
+ loss_sum.append(loss[name + '_loss'])
+ loss['loss'] = sum(loss_sum)
+ return loss
+ else:
+ if self.with_all: # for ViTSTR
+ tgt = batch[1]
+ pred = pred.reshape([-1, pred.shape[2]])
+ tgt = tgt.reshape([-1])
+ loss = self.loss_func(pred, tgt)
+ return {'loss': loss}
+ else: # for NRTR
+ max_len = batch[2].max()
+ tgt = batch[1][:, 1:2 + max_len]
+ pred = pred.reshape([-1, pred.shape[2]])
+ tgt = tgt.reshape([-1])
+ if self.smoothing:
+ eps = 0.1
+                    n_class = pred.shape[1]
+                    one_hot = F.one_hot(tgt, n_class)
+                    one_hot = one_hot * (1 - eps) + (
+                        1 - one_hot) * eps / (n_class - 1)
+ log_prb = F.log_softmax(pred, dim=1)
+ non_pad_mask = torch.not_equal(
+ tgt,
+ torch.zeros(tgt.shape,
+ dtype=tgt.dtype,
+ device=tgt.device))
+ loss = -(one_hot * log_prb).sum(dim=1)
+ loss = loss.masked_select(non_pad_mask).mean()
+ else:
+ loss = self.loss_func(pred, tgt)
+ return {'loss': loss}
diff --git a/openrec/losses/cppd_loss.py b/openrec/losses/cppd_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8af214c36a1cb21442786d68b1b328dc9547523
--- /dev/null
+++ b/openrec/losses/cppd_loss.py
@@ -0,0 +1,77 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class CPPDLoss(nn.Module):
+
+ def __init__(self,
+ smoothing=False,
+ ignore_index=100,
+ pos_len=False,
+ sideloss_weight=1.0,
+ max_len=25,
+ **kwargs):
+ super(CPPDLoss, self).__init__()
+ self.edge_ce = nn.CrossEntropyLoss(reduction='mean',
+ ignore_index=ignore_index)
+ self.char_node_ce = nn.CrossEntropyLoss(reduction='mean')
+ if pos_len:
+ self.pos_node_ce = nn.CrossEntropyLoss(reduction='mean',
+ ignore_index=ignore_index)
+ else:
+ self.pos_node_ce = nn.BCEWithLogitsLoss(reduction='mean')
+
+ self.smoothing = smoothing
+ self.ignore_index = ignore_index
+ self.pos_len = pos_len
+ self.sideloss_weight = sideloss_weight
+ self.max_len = max_len + 1
+
+ def label_smoothing_ce(self, preds, targets):
+ zeros_ = torch.zeros_like(targets)
+ ignore_index_ = zeros_ + self.ignore_index
+ non_pad_mask = torch.not_equal(targets, ignore_index_)
+
+ tgts = torch.where(targets == ignore_index_, zeros_, targets)
+ eps = 0.1
+ n_class = preds.shape[1]
+ one_hot = F.one_hot(tgts, preds.shape[1])
+ one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
+ log_prb = F.log_softmax(preds, dim=1)
+ loss = -(one_hot * log_prb).sum(dim=1)
+ loss = loss.masked_select(non_pad_mask).mean()
+ return loss
+
+ def forward(self, pred, batch):
+ node_feats, edge_feats = pred
+ node_tgt = batch[2]
+ char_tgt = batch[1]
+
+ # updated code
+ char_num_label = torch.clip(node_tgt[:, :-self.max_len].flatten(0, 1),
+ 0, node_feats[0].shape[-1] - 1)
+ loss_char_node = self.char_node_ce(node_feats[0].flatten(0, 1),
+ char_num_label)
+ if self.pos_len:
+ loss_pos_node = self.pos_node_ce(
+ node_feats[1].flatten(0, 1),
+ node_tgt[:, -self.max_len:].flatten(0, 1))
+ else:
+ loss_pos_node = self.pos_node_ce(
+ node_feats[1].flatten(0, 1),
+ node_tgt[:, -self.max_len:].flatten(0, 1).float())
+ loss_node = loss_char_node + loss_pos_node
+ # -----
+ edge_feats = edge_feats.flatten(0, 1)
+ char_tgt = char_tgt.flatten(0, 1)
+ if self.smoothing:
+ loss_edge = self.label_smoothing_ce(edge_feats, char_tgt)
+ else:
+ loss_edge = self.edge_ce(edge_feats, char_tgt)
+
+ return {
+ 'loss': self.sideloss_weight * loss_node + loss_edge,
+ 'loss_node': self.sideloss_weight * loss_node,
+ 'loss_edge': loss_edge,
+ }
diff --git a/openrec/losses/ctc_loss.py b/openrec/losses/ctc_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..de58b4b45a59c5cea031160fa0906c22601fe541
--- /dev/null
+++ b/openrec/losses/ctc_loss.py
@@ -0,0 +1,33 @@
+import torch
+from torch import nn
+
+
+class CTCLoss(nn.Module):
+
+ def __init__(self, use_focal_loss=False, zero_infinity=False, **kwargs):
+ super(CTCLoss, self).__init__()
+ self.loss_func = nn.CTCLoss(blank=0,
+ reduction='none',
+ zero_infinity=zero_infinity)
+ self.use_focal_loss = use_focal_loss
+
+ def forward(self, predicts, batch):
+ # predicts = predicts['res']
+
+ batch_size = predicts.size(0)
+ label, label_length = batch[1], batch[2]
+ predicts = predicts.log_softmax(2)
+ predicts = predicts.permute(1, 0, 2)
+ preds_lengths = torch.tensor([predicts.size(0)] * batch_size,
+ dtype=torch.long)
+ loss = self.loss_func(predicts, label, preds_lengths, label_length)
+
+ if self.use_focal_loss:
+ # Use torch.clamp to limit the range of loss, avoiding overflow in exponential calculation
+ clamped_loss = torch.clamp(loss, min=-20, max=20)
+ weight = 1 - torch.exp(-clamped_loss)
+ weight = torch.square(weight)
+ # Use torch.where to avoid multiplying by zero weight
+ loss = torch.where(weight > 0, loss * weight, loss)
+ loss = loss.mean()
+ return {'loss': loss}
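+
+
+if __name__ == '__main__':
+    # Minimal smoke test (illustrative): 2 samples, 10 time steps, 37 classes
+    # with blank=0; labels are padded to length 5 with per-sample lengths.
+    logits = torch.randn(2, 10, 37)
+    labels = torch.randint(1, 37, (2, 5), dtype=torch.long)
+    label_lengths = torch.tensor([5, 3], dtype=torch.long)
+    print(CTCLoss()(logits, [None, labels, label_lengths]))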
diff --git a/openrec/losses/igtr_loss.py b/openrec/losses/igtr_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..90c7bcabd5371db3f22eb3fede0c585c5af5a31f
--- /dev/null
+++ b/openrec/losses/igtr_loss.py
@@ -0,0 +1,12 @@
+from torch import nn
+
+
+class IGTRLoss(nn.Module):
+
+ def __init__(self, **kwargs):
+ super(IGTRLoss, self).__init__()
+
+ def forward(self, predicts, batch):
+ if isinstance(predicts, list):
+ predicts = predicts[0]
+ return predicts
diff --git a/openrec/losses/lister_loss.py b/openrec/losses/lister_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd544cd8c2e0e7af16f5056c840627f19981ee1a
--- /dev/null
+++ b/openrec/losses/lister_loss.py
@@ -0,0 +1,14 @@
+from torch import nn
+
+
+class LISTERLoss(nn.Module):
+
+ def __init__(self, **kwargs):
+ super(LISTERLoss, self).__init__()
+
+ def forward(self, predicts, batch):
+ # predicts = predicts['res']
+ # loss = predicts
+ if isinstance(predicts, list):
+ predicts = predicts[0]
+ return predicts
diff --git a/openrec/losses/lpv_loss.py b/openrec/losses/lpv_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..000d399cc58bfc1df94a4925af0d3f8e0b0a3342
--- /dev/null
+++ b/openrec/losses/lpv_loss.py
@@ -0,0 +1,30 @@
+import torch.nn.functional as F
+from torch import nn
+
+
+class LPVLoss(nn.Module):
+
+ def __init__(self, label_smoothing=0.0, **kwargs):
+ super(LPVLoss, self).__init__()
+ self.label_smoothing = label_smoothing
+
+ def forward(self, preds, batch):
+ max_len = batch[2].max()
+ tgt = batch[1][:, 1:2 + max_len]
+
+ tgt = tgt.flatten(0, 1)
+ loss = 0
+ loss_dict = {}
+ for i, pred in enumerate(preds):
+ pred = pred.flatten(0, 1)
+ loss_i = F.cross_entropy(
+ pred,
+ tgt,
+ reduction='mean',
+ label_smoothing=self.label_smoothing,
+ ignore_index=pred.shape[1] + 1,
+ ) # self.loss_func(pred, tgt)
+ loss += loss_i
+ loss_dict['loss' + str(i)] = loss_i
+ loss_dict['loss'] = loss
+ return loss_dict
diff --git a/openrec/losses/mgp_loss.py b/openrec/losses/mgp_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0951a5bccf8613a9d93ea6f045d264998205ebd3
--- /dev/null
+++ b/openrec/losses/mgp_loss.py
@@ -0,0 +1,34 @@
+from torch import nn
+
+
+class MGPLoss(nn.Module):
+
+ def __init__(self, only_char=False, **kwargs):
+ super(MGPLoss, self).__init__()
+ self.ce = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
+ self.only_char = only_char
+
+ def forward(self, pred, batch):
+ if self.only_char:
+ char_feats = pred
+ char_tgt = batch[1].flatten(0, 1)
+ char_loss = self.ce(char_feats.flatten(0, 1), char_tgt)
+ return {'loss': char_loss}
+ else:
+ return self.forward_all(pred, batch)
+
+ def forward_all(self, pred, batch):
+ char_feats, dpe_feats, wp_feats = pred
+ char_tgt = batch[1].flatten(0, 1)
+ dpe_tgt = batch[2].flatten(0, 1)
+ wp_tgt = batch[3].flatten(0, 1)
+ char_loss = self.ce(char_feats.flatten(0, 1), char_tgt)
+ dpe_loss = self.ce(dpe_feats.flatten(0, 1), dpe_tgt)
+ wp_loss = self.ce(wp_feats.flatten(0, 1), wp_tgt)
+ loss = char_loss + dpe_loss + wp_loss
+ return {
+ 'loss': loss,
+ 'char_loss': char_loss,
+ 'dpe_loss': dpe_loss,
+ 'wp_loss': wp_loss
+ }
diff --git a/openrec/losses/parseq_loss.py b/openrec/losses/parseq_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..49417bdfbb69f048e0d5a1b6a7eaf63260b81a2f
--- /dev/null
+++ b/openrec/losses/parseq_loss.py
@@ -0,0 +1,12 @@
+from torch import nn
+
+
+class PARSeqLoss(nn.Module):
+
+ def __init__(self, **kwargs):
+ super(PARSeqLoss, self).__init__()
+
+ def forward(self, predicts, batch):
+ # predicts = predicts['res']
+ loss, _ = predicts
+ return {'loss': loss}
diff --git a/openrec/losses/robustscanner_loss.py b/openrec/losses/robustscanner_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cca48d9b2ad75bce53e3fad7c51ff7f3646c393
--- /dev/null
+++ b/openrec/losses/robustscanner_loss.py
@@ -0,0 +1,20 @@
+from torch import nn
+
+
+class RobustScannerLoss(nn.Module):
+
+ def __init__(self, **kwargs):
+ super(RobustScannerLoss, self).__init__()
+ ignore_index = kwargs.get('ignore_index', 38)
+ self.loss_func = nn.CrossEntropyLoss(reduction='mean',
+ ignore_index=ignore_index)
+
+ def forward(self, pred, batch):
+ pred = pred[:, :-1, :]
+
+ label = batch[1][:, 1:].reshape([-1])
+
+ inputs = pred.reshape([-1, pred.shape[2]])
+
+ loss = self.loss_func(inputs, label)
+ return {'loss': loss}
diff --git a/openrec/losses/seed_loss.py b/openrec/losses/seed_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eceac40e3d0aac7d79cd8da4e002ff1fbbb76b3
--- /dev/null
+++ b/openrec/losses/seed_loss.py
@@ -0,0 +1,46 @@
+import torch.nn.functional as F
+from torch import nn
+import torch
+
+
+class CosineEmbeddingLoss(nn.Module):
+
+ def __init__(self, margin=0.0):
+ super(CosineEmbeddingLoss, self).__init__()
+ self.margin = margin
+ self.epsilon = 1e-12
+
+ def forward(self, x1, x2):
+ similarity = torch.sum(x1 * x2, axis=-1) / (
+ torch.norm(x1, dim=-1) * torch.norm(x2, dim=-1) + self.epsilon)
+ return (1 - similarity).mean()
+
+
+class SEEDLoss(nn.Module):
+
+ def __init__(self, label_smoothing=0.1, ignore_index=0, **kwargs):
+ super(SEEDLoss, self).__init__()
+ self.label_smoothing = label_smoothing
+ self.loss_sem = CosineEmbeddingLoss()
+
+ def forward(self, preds, batch):
+ embedding_vectors, pred = preds
+ max_len = batch[2].max()
+ tgt = batch[1][:, 1:2 + max_len]
+ pred = pred.flatten(0, 1)
+ tgt = tgt.reshape([-1])
+ loss = F.cross_entropy(
+ pred,
+ tgt,
+ reduction='mean',
+ label_smoothing=self.label_smoothing,
+ ignore_index=pred.shape[1] + 1,
+ ) # self.loss_func(pred, tgt)
+ sem_target = batch[3].float()
+
+ sem_loss = torch.sum(self.loss_sem(embedding_vectors, sem_target))
+ return {
+ 'loss': loss + 0.1 * sem_loss,
+ 'rec_loss': loss,
+ 'sem_loss': sem_loss
+ }
diff --git a/openrec/losses/smtr_loss.py b/openrec/losses/smtr_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..348ae104308b23e47e35cc6f3ce36cd72633c89a
--- /dev/null
+++ b/openrec/losses/smtr_loss.py
@@ -0,0 +1,12 @@
+from torch import nn
+
+
+class SMTRLoss(nn.Module):
+
+ def __init__(self, **kwargs):
+ super(SMTRLoss, self).__init__()
+
+ def forward(self, predicts, batch):
+ if isinstance(predicts, list):
+ predicts = predicts[0]
+ return predicts
diff --git a/openrec/losses/srn_loss.py b/openrec/losses/srn_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..52648ac697c2e85e9bcb79fd92c5e69e9ebc2c64
--- /dev/null
+++ b/openrec/losses/srn_loss.py
@@ -0,0 +1,40 @@
+import torch.nn.functional as F
+from torch import nn
+
+
+class SRNLoss(nn.Module):
+
+ def __init__(self, label_smoothing=0.0, **kwargs):
+ super(SRNLoss, self).__init__()
+ self.label_smoothing = label_smoothing
+
+ def forward(self, preds, batch):
+ pvam_preds, gsrm_preds, vsfd_preds = preds
+
+ label = batch[1].reshape([-1])
+
+ ignore_index = pvam_preds.shape[-1] + 1
+
+ loss_pvam = F.cross_entropy(pvam_preds,
+ label,
+ reduction='mean',
+ label_smoothing=self.label_smoothing,
+ ignore_index=ignore_index)
+ loss_gsrm = F.cross_entropy(gsrm_preds,
+ label,
+ reduction='mean',
+ label_smoothing=self.label_smoothing,
+ ignore_index=ignore_index)
+ loss_vsfd = F.cross_entropy(vsfd_preds,
+ label,
+ reduction='mean',
+ label_smoothing=self.label_smoothing,
+ ignore_index=ignore_index)
+
+ loss_dict = {}
+ loss_dict['loss_pvam'] = loss_pvam
+ loss_dict['loss_gsrm'] = loss_gsrm
+ loss_dict['loss_vsfd'] = loss_vsfd
+
+ loss_dict['loss'] = loss_pvam * 3.0 + loss_gsrm * 0.15 + loss_vsfd
+ return loss_dict
diff --git a/openrec/losses/visionlan_loss.py b/openrec/losses/visionlan_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b7591555d2e34a0dd1fdc2effcf6473a7e4f948
--- /dev/null
+++ b/openrec/losses/visionlan_loss.py
@@ -0,0 +1,58 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+def flatten_label(target):
+ label_flatten = []
+ label_length = []
+ for i in range(0, target.size()[0]):
+ cur_label = target[i].tolist()
+ label_flatten += cur_label[:cur_label.index(0) + 1]
+ label_length.append(cur_label.index(0) + 1)
+ label_flatten = torch.LongTensor(label_flatten)
+ label_length = torch.IntTensor(label_length)
+ return (label_flatten, label_length)
+
+
+def _flatten(sources, lengths):
+ return torch.cat([t[:l] for t, l in zip(sources, lengths)])
+
+
+class VisionLANLoss(nn.Module):
+
+ def __init__(self,
+ training_step='LA',
+ ratio_res=0.5,
+ ratio_sub=0.5,
+ **kwargs):
+ super(VisionLANLoss, self).__init__()
+ self.loss_func = nn.CrossEntropyLoss(reduction='mean')
+ self.ratio_res = ratio_res
+ self.ratio_sub = ratio_sub
+ assert training_step in ['LF_1', 'LF_2', 'LA']
+ self.training_step = training_step
+
+ def forward(self, pred, batch):
+ text_pre, text_rem, text_mas, _ = pred
+ target = batch[1].to(dtype=torch.int64)
+ label_flatten, length = flatten_label(target)
+ text_pre = _flatten(text_pre, length)
+ if self.training_step == 'LF_1':
+ loss = self.loss_func(text_pre, label_flatten.to(text_pre.device))
+ else:
+ target_res = batch[2].to(dtype=torch.int64)
+ target_sub = batch[3].to(dtype=torch.int64)
+ label_flatten_res, length_res = flatten_label(target_res)
+ label_flatten_sub, length_sub = flatten_label(target_sub)
+ text_rem = _flatten(text_rem, length_res)
+ text_mas = _flatten(text_mas, length_sub)
+ loss_ori = self.loss_func(text_pre,
+ label_flatten.to(text_pre.device))
+ loss_res = self.loss_func(text_rem,
+ label_flatten_res.to(text_rem.device))
+ loss_mas = self.loss_func(text_mas,
+ label_flatten_sub.to(text_mas.device))
+ loss = loss_ori + loss_res * self.ratio_res + loss_mas * self.ratio_sub
+
+ return {'loss': loss}
diff --git a/openrec/metrics/__init__.py b/openrec/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..281b6dcdd2ba345d1c87c6e572a91fe67edd5d42
--- /dev/null
+++ b/openrec/metrics/__init__.py
@@ -0,0 +1,19 @@
+import copy
+
+__all__ = ['build_metric']
+
+from .rec_metric import RecMetric
+from .rec_metric_gtc import RecGTCMetric
+from .rec_metric_long import RecMetricLong
+from .rec_metric_mgp import RecMPGMetric
+
+support_dict = ['RecMetric', 'RecMetricLong', 'RecGTCMetric', 'RecMPGMetric']
+
+
+def build_metric(config):
+ config = copy.deepcopy(config)
+ module_name = config.pop('name')
+    assert module_name in support_dict, Exception(
+        'metric only supports {}'.format(support_dict))
+ module_class = eval(module_name)(**config)
+ return module_class
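+
+
+# Usage sketch (illustrative): metric = build_metric({'name': 'RecMetric',
+# 'is_filter': True}); the instance is then called on (preds, labels) batches
+# and summarised with metric.get_metric().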
diff --git a/openrec/metrics/__pycache__/__init__.cpython-38.pyc b/openrec/metrics/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4fdf54bbd23dcb830b69cad766f3ea8e996fd506
Binary files /dev/null and b/openrec/metrics/__pycache__/__init__.cpython-38.pyc differ
diff --git a/openrec/metrics/__pycache__/rec_metric.cpython-38.pyc b/openrec/metrics/__pycache__/rec_metric.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e969526b46de3a38606e05dad893179394bf9cde
Binary files /dev/null and b/openrec/metrics/__pycache__/rec_metric.cpython-38.pyc differ
diff --git a/openrec/metrics/__pycache__/rec_metric_gtc.cpython-38.pyc b/openrec/metrics/__pycache__/rec_metric_gtc.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14c09499e6f5b18817fb8be4f3327f5e71f4bb72
Binary files /dev/null and b/openrec/metrics/__pycache__/rec_metric_gtc.cpython-38.pyc differ
diff --git a/openrec/metrics/__pycache__/rec_metric_long.cpython-38.pyc b/openrec/metrics/__pycache__/rec_metric_long.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f06abf32ba236b908b98cd6bdc572ed7045ae14a
Binary files /dev/null and b/openrec/metrics/__pycache__/rec_metric_long.cpython-38.pyc differ
diff --git a/openrec/metrics/__pycache__/rec_metric_mgp.cpython-38.pyc b/openrec/metrics/__pycache__/rec_metric_mgp.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..790023dc78471159cc7cd913904afc51a0439afc
Binary files /dev/null and b/openrec/metrics/__pycache__/rec_metric_mgp.cpython-38.pyc differ
diff --git a/openrec/metrics/rec_metric.py b/openrec/metrics/rec_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6335420e485b80e041243313791a25e74c235e8
--- /dev/null
+++ b/openrec/metrics/rec_metric.py
@@ -0,0 +1,270 @@
+import string
+import numpy as np
+from rapidfuzz.distance import Levenshtein
+
+
+def match_ss(ss1, ss2):
+ s1_len = len(ss1)
+ for c_i in range(s1_len):
+ if ss1[c_i:] == ss2[:s1_len - c_i]:
+ return ss2[s1_len - c_i:]
+ return ss2
+
+
+def stream_match(text):
+ bs = len(text)
+ s_list = []
+ conf_list = []
+ for s_conf in text:
+ s_list.append(s_conf[0])
+ conf_list.append(s_conf[1])
+ s_n = bs
+ s_start = s_list[0][:-1]
+ s_new = s_start
+ for s_i in range(1, s_n):
+ s_start = match_ss(
+ s_start, s_list[s_i][1:-1] if s_i < s_n - 1 else s_list[s_i][1:])
+ s_new += s_start
+ return s_new, sum(conf_list) / bs
+
+
+class RecMetric(object):
+
+ def __init__(self,
+ main_indicator='acc',
+ is_filter=False,
+ is_lower=True,
+ ignore_space=True,
+ stream=False,
+ with_ratio=False,
+ max_len=25,
+ max_ratio=4,
+ **kwargs):
+ self.main_indicator = main_indicator
+ self.is_filter = is_filter
+ self.is_lower = is_lower
+ self.ignore_space = ignore_space
+ self.stream = stream
+ self.eps = 1e-5
+ self.with_ratio = with_ratio
+ self.max_len = max_len
+ self.max_ratio = max_ratio
+ self.reset()
+
+ def _normalize_text(self, text):
+ text = ''.join(
+ filter(lambda x: x in (string.digits + string.ascii_letters),
+ text))
+ return text
+
+ def __call__(self,
+ pred_label,
+ batch=None,
+ training=False,
+ *args,
+ **kwargs):
+ if self.with_ratio and not training:
+ return self.eval_all_metric(pred_label, batch)
+ else:
+ return self.eval_metric(pred_label)
+
+ def eval_metric(self, pred_label, *args, **kwargs):
+ preds, labels = pred_label
+ correct_num = 0
+ all_num = 0
+ norm_edit_dis = 0.0
+ for (pred, pred_conf), (target, _) in zip(preds, labels):
+ if self.stream:
+ assert len(labels) == 1
+ pred, _ = stream_match(preds)
+ if self.ignore_space:
+ pred = pred.replace(' ', '')
+ target = target.replace(' ', '')
+ if self.is_filter:
+ pred = self._normalize_text(pred)
+ target = self._normalize_text(target)
+ if self.is_lower:
+ pred = pred.lower()
+ target = target.lower()
+ norm_edit_dis += Levenshtein.normalized_distance(pred, target)
+ if pred == target:
+ correct_num += 1
+ all_num += 1
+ self.correct_num += correct_num
+ self.all_num += all_num
+ self.norm_edit_dis += norm_edit_dis
+ return {
+ 'acc': correct_num / (all_num + self.eps),
+ 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
+ }
+
+ def eval_all_metric(self, pred_label, batch=None, *args, **kwargs):
+ if self.with_ratio:
+ ratio = batch[-1]
+ preds, labels = pred_label
+ correct_num = 0
+ correct_num_real = 0
+ correct_num_lower = 0
+ correct_num_ignore_space = 0
+ correct_num_ignore_space_lower = 0
+ correct_num_ignore_space_symbol = 0
+ all_num = 0
+ norm_edit_dis = 0.0
+ each_len_num = [0 for _ in range(self.max_len)]
+ each_len_correct_num = [0 for _ in range(self.max_len)]
+ each_len_norm_edit_dis = [0 for _ in range(self.max_len)]
+ each_ratio_num = [0 for _ in range(self.max_ratio)]
+ each_ratio_correct_num = [0 for _ in range(self.max_ratio)]
+ each_ratio_norm_edit_dis = [0 for _ in range(self.max_ratio)]
+ for (pred, pred_conf), (target, _) in zip(preds, labels):
+ if self.stream:
+ assert len(labels) == 1
+ pred, _ = stream_match(preds)
+ if pred == target:
+ correct_num_real += 1
+
+ if pred.lower() == target.lower():
+ correct_num_lower += 1
+
+ if self.ignore_space:
+ pred = pred.replace(' ', '')
+ target = target.replace(' ', '')
+ if pred == target:
+ correct_num_ignore_space += 1
+
+ if pred.lower() == target.lower():
+ correct_num_ignore_space_lower += 1
+
+ if self.is_filter:
+ pred = self._normalize_text(pred)
+ target = self._normalize_text(target)
+ if pred == target:
+ correct_num_ignore_space_symbol += 1
+
+ if self.is_lower:
+ pred = pred.lower()
+ target = target.lower()
+ dis = Levenshtein.normalized_distance(pred, target)
+ norm_edit_dis += dis
+ ratio_i = ratio[all_num] - 1 if ratio[
+ all_num] < self.max_ratio else self.max_ratio - 1
+ len_i = max(0, min(self.max_len, len(target)) - 1)
+ if pred == target:
+ correct_num += 1
+ each_len_correct_num[len_i] += 1
+ each_ratio_correct_num[ratio_i] += 1
+ each_len_num[len_i] += 1
+ each_len_norm_edit_dis[len_i] += dis
+
+ each_ratio_num[ratio_i] += 1
+ each_ratio_norm_edit_dis[ratio_i] += dis
+ all_num += 1
+ self.correct_num += correct_num
+ self.correct_num_real += correct_num_real
+ self.correct_num_lower += correct_num_lower
+ self.correct_num_ignore_space += correct_num_ignore_space
+ self.correct_num_ignore_space_lower += correct_num_ignore_space_lower
+ self.correct_num_ignore_space_symbol += correct_num_ignore_space_symbol
+ self.all_num += all_num
+ self.norm_edit_dis += norm_edit_dis
+ self.each_len_num = self.each_len_num + np.array(each_len_num)
+ self.each_len_correct_num = self.each_len_correct_num + np.array(
+ each_len_correct_num)
+ self.each_len_norm_edit_dis = self.each_len_norm_edit_dis + np.array(
+ each_len_norm_edit_dis)
+ self.each_ratio_num = self.each_ratio_num + np.array(each_ratio_num)
+ self.each_ratio_correct_num = self.each_ratio_correct_num + np.array(
+ each_ratio_correct_num)
+ self.each_ratio_norm_edit_dis = self.each_ratio_norm_edit_dis + np.array(
+ each_ratio_norm_edit_dis)
+ return {
+ 'acc': correct_num / (all_num + self.eps),
+ 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
+ }
+
+ def get_metric(self, training=False):
+ """
+ return metrics {
+ 'acc': 0,
+ 'norm_edit_dis': 0,
+ }
+ """
+ if self.with_ratio and not training:
+ return self.get_all_metric()
+ acc = 1.0 * self.correct_num / (self.all_num + self.eps)
+ norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
+ num_samples = self.all_num
+ self.reset()
+ return {
+ 'acc': acc,
+ 'norm_edit_dis': norm_edit_dis,
+ 'num_samples': num_samples
+ }
+
+ def get_all_metric(self):
+ """
+ return metrics {
+ 'acc': 0,
+ 'norm_edit_dis': 0,
+ }
+ """
+ acc = 1.0 * self.correct_num / (self.all_num + self.eps)
+ acc_real = 1.0 * self.correct_num_real / (self.all_num + self.eps)
+ acc_lower = 1.0 * self.correct_num_lower / (self.all_num + self.eps)
+ acc_ignore_space = 1.0 * self.correct_num_ignore_space / (
+ self.all_num + self.eps)
+ acc_ignore_space_lower = 1.0 * self.correct_num_ignore_space_lower / (
+ self.all_num + self.eps)
+ acc_ignore_space_symbol = 1.0 * self.correct_num_ignore_space_symbol / (
+ self.all_num + self.eps)
+
+ norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
+ num_samples = self.all_num
+ each_len_acc = (self.each_len_correct_num /
+ (self.each_len_num + self.eps)).tolist()
+ each_len_norm_edit_dis = (1 -
+ ((self.each_len_norm_edit_dis) /
+ ((self.each_len_num) + self.eps))).tolist()
+ each_len_num = self.each_len_num.tolist()
+ each_ratio_acc = (self.each_ratio_correct_num /
+ (self.each_ratio_num + self.eps)).tolist()
+ each_ratio_norm_edit_dis = (1 - ((self.each_ratio_norm_edit_dis) / (
+ (self.each_ratio_num) + self.eps))).tolist()
+ each_ratio_num = self.each_ratio_num.tolist()
+ self.reset()
+ return {
+ 'acc': acc,
+ 'acc_real': acc_real,
+ 'acc_lower': acc_lower,
+ 'acc_ignore_space': acc_ignore_space,
+ 'acc_ignore_space_lower': acc_ignore_space_lower,
+ 'acc_ignore_space_symbol': acc_ignore_space_symbol,
+ 'acc_ignore_space_lower_symbol': acc,
+ 'each_len_num': each_len_num,
+ 'each_len_acc': each_len_acc,
+ 'each_len_norm_edit_dis': each_len_norm_edit_dis,
+ 'each_ratio_num': each_ratio_num,
+ 'each_ratio_acc': each_ratio_acc,
+ 'each_ratio_norm_edit_dis': each_ratio_norm_edit_dis,
+ 'norm_edit_dis': norm_edit_dis,
+ 'num_samples': num_samples
+ }
+
+ def reset(self):
+ self.correct_num = 0
+ self.all_num = 0
+ self.norm_edit_dis = 0
+ self.correct_num_real = 0
+ self.correct_num_lower = 0
+ self.correct_num_ignore_space = 0
+ self.correct_num_ignore_space_lower = 0
+ self.correct_num_ignore_space_symbol = 0
+ self.each_len_num = np.array([0 for _ in range(self.max_len)])
+ self.each_len_correct_num = np.array([0 for _ in range(self.max_len)])
+ self.each_len_norm_edit_dis = np.array(
+ [0. for _ in range(self.max_len)])
+ self.each_ratio_num = np.array([0 for _ in range(self.max_ratio)])
+ self.each_ratio_correct_num = np.array(
+ [0 for _ in range(self.max_ratio)])
+ self.each_ratio_norm_edit_dis = np.array(
+ [0. for _ in range(self.max_ratio)])
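+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative): predictions and labels are lists of
+    # (text, confidence) pairs; per-batch stats are accumulated and the final
+    # summary is obtained from get_metric().
+    metric = RecMetric(is_filter=True)
+    preds = [('hello', 0.9), ('w0rld', 0.8)]
+    labels = [('hello', 1.0), ('world', 1.0)]
+    print(metric((preds, labels)))
+    print(metric.get_metric())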
diff --git a/openrec/metrics/rec_metric_gtc.py b/openrec/metrics/rec_metric_gtc.py
new file mode 100644
index 0000000000000000000000000000000000000000..f023513b32429ac04e9782bbbc275522f295e170
--- /dev/null
+++ b/openrec/metrics/rec_metric_gtc.py
@@ -0,0 +1,58 @@
+from .rec_metric import RecMetric
+
+
+class RecGTCMetric(object):
+
+ def __init__(self,
+ main_indicator='acc',
+ is_filter=False,
+ ignore_space=True,
+ stream=False,
+ with_ratio=False,
+ max_len=25,
+ max_ratio=4,
+ **kwargs):
+ self.main_indicator = main_indicator
+ self.is_filter = is_filter
+ self.ignore_space = ignore_space
+ self.eps = 1e-5
+ self.gtc_metric = RecMetric(main_indicator=main_indicator,
+ is_filter=is_filter,
+ ignore_space=ignore_space,
+ stream=stream,
+ with_ratio=with_ratio,
+ max_len=max_len,
+ max_ratio=max_ratio)
+ self.ctc_metric = RecMetric(main_indicator=main_indicator,
+ is_filter=is_filter,
+ ignore_space=ignore_space,
+ stream=stream,
+ with_ratio=with_ratio,
+ max_len=max_len,
+ max_ratio=max_ratio)
+
+ def __call__(self,
+ pred_label,
+ batch=None,
+ training=False,
+ *args,
+ **kwargs):
+
+ ctc_metric = self.ctc_metric(pred_label[1], batch, training=training)
+ gtc_metric = self.gtc_metric(pred_label[0], batch, training=training)
+ ctc_metric['gtc_acc'] = gtc_metric['acc']
+ ctc_metric['gtc_norm_edit_dis'] = gtc_metric['norm_edit_dis']
+ return ctc_metric
+
+ def get_metric(self):
+ """
+ return metrics {
+ 'acc': 0,
+ 'norm_edit_dis': 0,
+ }
+ """
+ ctc_metric = self.ctc_metric.get_metric()
+ gtc_metric = self.gtc_metric.get_metric()
+ ctc_metric['gtc_acc'] = gtc_metric['acc']
+ ctc_metric['gtc_norm_edit_dis'] = gtc_metric['norm_edit_dis']
+ return ctc_metric
diff --git a/openrec/metrics/rec_metric_long.py b/openrec/metrics/rec_metric_long.py
new file mode 100644
index 0000000000000000000000000000000000000000..941cf1d76fefc73716cfce58157c8cb7b6241aec
--- /dev/null
+++ b/openrec/metrics/rec_metric_long.py
@@ -0,0 +1,142 @@
+import string
+
+import numpy as np
+from rapidfuzz.distance import Levenshtein
+
+from .rec_metric import stream_match
+
+# f_pred = open('pred_focal_subs_rand1_h2_bi_first.txt', 'w')
+
+
+class RecMetricLong(object):
+
+ def __init__(self,
+ main_indicator='acc',
+ is_filter=False,
+ ignore_space=True,
+ stream=False,
+ **kwargs):
+ self.main_indicator = main_indicator
+ self.is_filter = is_filter
+ self.ignore_space = ignore_space
+ self.stream = stream
+ self.eps = 1e-5
+ self.max_len = 201
+ self.reset()
+
+ def _normalize_text(self, text):
+ text = ''.join(
+ filter(lambda x: x in (string.digits + string.ascii_letters),
+ text))
+ return text.lower()
+
+ def __call__(self, pred_label, *args, **kwargs):
+ preds, labels = pred_label
+ correct_num = 0
+ correct_num_slice = 0
+ f_l_acc = 0
+ all_num = 0
+ norm_edit_dis = 0.0
+ len_acc = 0
+ each_len_num = [0 for _ in range(self.max_len)]
+ each_len_correct_num = [0 for _ in range(self.max_len)]
+ each_len_norm_edit_dis = [0 for _ in range(self.max_len)]
+ for (pred, pred_conf), (target, _) in zip(preds, labels):
+ if self.stream:
+ assert len(labels) == 1
+ pred, _ = stream_match(preds)
+ if self.ignore_space:
+ pred = pred.replace(' ', '')
+ target = target.replace(' ', '')
+ if self.is_filter:
+ pred = self._normalize_text(pred)
+ target = self._normalize_text(target)
+ dis = Levenshtein.normalized_distance(pred, target)
+ norm_edit_dis += dis
+ # print(pred, target)
+ if pred == target:
+ correct_num += 1
+ each_len_correct_num[len(target)] += 1
+ each_len_num[len(target)] += 1
+ each_len_norm_edit_dis[len(target)] += dis
+ # f_pred.write(pred+'\t'+target+'\t1'+'\n')
+ # print(pred, target, 1)
+ # else:
+ # f_pred.write(pred+'\t'+target+'\t0'+'\n')
+ # print(pred, target, 0)
+ if len(pred) >= 1 and len(target) >= 1:
+ if pred[0] == target[0] and pred[-1] == target[-1]:
+ f_l_acc += 1
+ if len(pred) == len(target):
+ len_acc += 1
+ if pred == target[:len(pred)]:
+ # if pred == target[-len(pred):]:
+ correct_num_slice += 1
+ all_num += 1
+ self.correct_num += correct_num
+ self.correct_num_slice += correct_num_slice
+ self.f_l_acc += f_l_acc
+ self.all_num += all_num
+ self.len_acc += len_acc
+ self.each_len_num = self.each_len_num + np.array(each_len_num)
+ self.each_len_correct_num = self.each_len_correct_num + np.array(
+ each_len_correct_num)
+ self.each_len_norm_edit_dis = self.each_len_norm_edit_dis + np.array(
+ each_len_norm_edit_dis)
+ self.norm_edit_dis += norm_edit_dis
+ return {
+ 'acc': correct_num / (all_num + self.eps),
+ 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
+ }
+
+ def get_metric(self):
+ """
+ return metrics {
+ 'acc': 0,
+ 'norm_edit_dis': 0,
+ }
+ """
+ acc = 1.0 * self.correct_num / (self.all_num + self.eps)
+ acc_slice = 1.0 * self.correct_num_slice / (self.all_num + self.eps)
+ f_l_acc = 1.0 * self.f_l_acc / (self.all_num + self.eps)
+ len_acc = 1.0 * self.len_acc / (self.all_num + self.eps)
+ norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
+ each_len_acc = (self.each_len_correct_num /
+ (self.each_len_num + self.eps)).tolist()
+ # each_len_acc_25 = each_len_acc[:26]
+ # each_len_acc_26 = each_len_acc[26:]
+ each_len_norm_edit_dis = (1 -
+ ((self.each_len_norm_edit_dis) /
+ ((self.each_len_num) + self.eps))).tolist()
+ # each_len_norm_edit_dis_25 = each_len_norm_edit_dis[:26]
+ # each_len_norm_edit_dis_26 = each_len_norm_edit_dis[26:]
+ each_len_num = self.each_len_num.tolist()
+ all_num = self.all_num
+ self.reset()
+ return {
+ 'acc': acc,
+ 'norm_edit_dis': norm_edit_dis,
+ 'acc_slice': acc_slice,
+ 'f_l_acc': f_l_acc,
+ 'len_acc': len_acc,
+ 'each_len_num': each_len_num,
+ 'each_len_acc': each_len_acc,
+ # "each_len_acc_25": each_len_acc_25,
+ # "each_len_acc_26": each_len_acc_26,
+ 'each_len_norm_edit_dis': each_len_norm_edit_dis,
+ # "each_len_norm_edit_dis_25":each_len_norm_edit_dis_25,
+ # "each_len_norm_edit_dis_26":each_len_norm_edit_dis_26,
+ 'all_num': all_num
+ }
+
+ def reset(self):
+ self.correct_num = 0
+ self.all_num = 0
+ self.norm_edit_dis = 0
+ self.correct_num_slice = 0
+ self.each_len_num = np.array([0 for _ in range(self.max_len)])
+ self.each_len_correct_num = np.array([0 for _ in range(self.max_len)])
+ self.each_len_norm_edit_dis = np.array(
+ [0. for _ in range(self.max_len)])
+ self.f_l_acc = 0
+ self.len_acc = 0
diff --git a/openrec/metrics/rec_metric_mgp.py b/openrec/metrics/rec_metric_mgp.py
new file mode 100644
index 0000000000000000000000000000000000000000..3232b4c84b95180b65ad5d5367aed48afa16062c
--- /dev/null
+++ b/openrec/metrics/rec_metric_mgp.py
@@ -0,0 +1,93 @@
+from .rec_metric import RecMetric
+
+
+class RecMPGMetric(object):
+
+ def __init__(self,
+ main_indicator='acc',
+ is_filter=False,
+ ignore_space=True,
+ stream=False,
+ with_ratio=False,
+ max_len=25,
+ max_ratio=4,
+ **kwargs):
+ self.main_indicator = main_indicator
+ self.is_filter = is_filter
+ self.ignore_space = ignore_space
+ self.eps = 1e-5
+ self.char_metric = RecMetric(main_indicator=main_indicator,
+ is_filter=is_filter,
+ ignore_space=ignore_space,
+ stream=stream,
+ with_ratio=with_ratio,
+ max_len=max_len,
+ max_ratio=max_ratio)
+ self.bpe_metric = RecMetric(main_indicator=main_indicator,
+ is_filter=is_filter,
+ ignore_space=ignore_space,
+ stream=stream,
+ with_ratio=with_ratio,
+ max_len=max_len,
+ max_ratio=max_ratio)
+
+ self.wp_metric = RecMetric(main_indicator=main_indicator,
+ is_filter=is_filter,
+ ignore_space=ignore_space,
+ stream=stream,
+ with_ratio=with_ratio,
+ max_len=max_len,
+ max_ratio=max_ratio)
+ self.final_metric = RecMetric(main_indicator=main_indicator,
+ is_filter=is_filter,
+ ignore_space=ignore_space,
+ stream=stream,
+ with_ratio=with_ratio,
+ max_len=max_len,
+ max_ratio=max_ratio)
+
+ def __call__(self,
+ pred_label,
+ batch=None,
+ training=False,
+ *args,
+ **kwargs):
+
+ char_metric = self.char_metric((pred_label[0], pred_label[-1]),
+ batch,
+ training=training)
+ bpe_metric = self.bpe_metric((pred_label[1], pred_label[-1]),
+ batch,
+ training=training)
+ wp_metric = self.wp_metric((pred_label[2], pred_label[-1]),
+ batch,
+ training=training)
+ final_metric = self.final_metric((pred_label[3], pred_label[-1]),
+ batch,
+ training=training)
+ final_metric['char_acc'] = char_metric['acc']
+ final_metric['char_norm_edit_dis'] = char_metric['norm_edit_dis']
+ final_metric['bpe_acc'] = bpe_metric['acc']
+ final_metric['bpe_norm_edit_dis'] = bpe_metric['norm_edit_dis']
+ final_metric['wp_acc'] = wp_metric['acc']
+ final_metric['wp_norm_edit_dis'] = wp_metric['norm_edit_dis']
+ return final_metric
+
+ def get_metric(self):
+ """
+ return metrics {
+ 'acc': 0,
+ 'norm_edit_dis': 0,
+ }
+ """
+ char_metric = self.char_metric.get_metric()
+ bpe_metric = self.bpe_metric.get_metric()
+ wp_metric = self.wp_metric.get_metric()
+ final_metric = self.final_metric.get_metric()
+ final_metric['char_acc'] = char_metric['acc']
+ final_metric['char_norm_edit_dis'] = char_metric['norm_edit_dis']
+ final_metric['bpe_acc'] = bpe_metric['acc']
+ final_metric['bpe_norm_edit_dis'] = bpe_metric['norm_edit_dis']
+ final_metric['wp_acc'] = wp_metric['acc']
+ final_metric['wp_norm_edit_dis'] = wp_metric['norm_edit_dis']
+ return final_metric
diff --git a/openrec/modeling/__init__.py b/openrec/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..be3232486d1d244de91e7e632ccb0224c00fd5f3
--- /dev/null
+++ b/openrec/modeling/__init__.py
@@ -0,0 +1,11 @@
+import copy
+
+from .base_recognizer import BaseRecognizer
+
+__all__ = ['build_model']
+
+
+def build_model(config):
+ config = copy.deepcopy(config)
+ rec_model = BaseRecognizer(config)
+ return rec_model
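+
+
+# Usage sketch (illustrative): model = build_model(config), where config holds
+# the optional 'Transform', 'Encoder' and 'Decoder' sub-configs consumed by
+# BaseRecognizer.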
diff --git a/openrec/modeling/__pycache__/__init__.cpython-38.pyc b/openrec/modeling/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0fbf5b02250ef60b346beaaf27b45ef38391a6e4
Binary files /dev/null and b/openrec/modeling/__pycache__/__init__.cpython-38.pyc differ
diff --git a/openrec/modeling/__pycache__/base_recognizer.cpython-38.pyc b/openrec/modeling/__pycache__/base_recognizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae7070c5e642892d57d4c20ab2698670f4b229d6
Binary files /dev/null and b/openrec/modeling/__pycache__/base_recognizer.cpython-38.pyc differ
diff --git a/openrec/modeling/__pycache__/common.cpython-38.pyc b/openrec/modeling/__pycache__/common.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..712b8021c2207993cd5d4c093e386e2d079b43aa
Binary files /dev/null and b/openrec/modeling/__pycache__/common.cpython-38.pyc differ
diff --git a/openrec/modeling/base_recognizer.py b/openrec/modeling/base_recognizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..97bf66f6691485b46a9df05ef47341fb5b37ae93
--- /dev/null
+++ b/openrec/modeling/base_recognizer.py
@@ -0,0 +1,69 @@
+import torch
+from torch import nn
+
+from openrec.modeling.decoders import build_decoder
+from openrec.modeling.encoders import build_encoder
+from openrec.modeling.transforms import build_transform
+
+__all__ = ['BaseRecognizer']
+
+
+class BaseRecognizer(nn.Module):
+
+ def __init__(self, config):
+        """The base recognition model, composed of optional Transform, Encoder and Decoder modules.
+
+        args:
+            config (dict): the configuration dict for the module.
+        """
+ super(BaseRecognizer, self).__init__()
+ in_channels = config.get('in_channels', 3)
+ self.use_wd = config.get('use_wd', True)
+        # build transform
+        # for rec, the transform can be TPS or None
+ if 'Transform' not in config or config['Transform'] is None:
+ self.use_transform = False
+ else:
+ self.use_transform = True
+ config['Transform']['in_channels'] = in_channels
+ self.transform = build_transform(config['Transform'])
+ in_channels = self.transform.out_channels
+
+ # build backbone
+ if 'Encoder' not in config or config['Encoder'] is None:
+ self.use_encoder = False
+ else:
+ self.use_encoder = True
+ config['Encoder']['in_channels'] = in_channels
+ self.encoder = build_encoder(config['Encoder'])
+ in_channels = self.encoder.out_channels
+
+ # build decoder
+ if 'Decoder' not in config or config['Decoder'] is None:
+ self.use_decoder = False
+ else:
+ self.use_decoder = True
+ config['Decoder']['in_channels'] = in_channels
+ self.decoder = build_decoder(config['Decoder'])
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ if self.use_wd:
+ if hasattr(self.encoder, 'no_weight_decay'):
+ no_weight_decay = self.encoder.no_weight_decay()
+ else:
+                no_weight_decay = set()  # use a set so .update() works with the decoder's name set
+ if hasattr(self.decoder, 'no_weight_decay'):
+ no_weight_decay.update(self.decoder.no_weight_decay())
+ return no_weight_decay
+ else:
+            return set()
+
+ def forward(self, x, data=None):
+ if self.use_transform:
+ x = self.transform(x)
+ if self.use_encoder:
+ x = self.encoder(x)
+ if self.use_decoder:
+ x = self.decoder(x, data=data)
+ return x
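`build_model` simply deep-copies the config and hands it to `BaseRecognizer`, which chains an optional rectification Transform, an optional Encoder (backbone) and an optional Decoder (head), wiring each stage's `in_channels` from the previous stage's `out_channels`. A minimal usage sketch follows — the `'SVTRNet'` encoder name and the 32x128 input size are assumptions for illustration (the encoder registry is outside this hunk), while `'CTCDecoder'` matches the decoder registry added later in this diff.

```python
# Minimal sketch of building and calling a recognizer (illustrative only).
# 'SVTRNet' and the input resolution are assumed; 'CTCDecoder' is registered below.
import torch

from openrec.modeling import build_model

config = {
    'in_channels': 3,
    'Transform': None,                                       # skip rectification
    'Encoder': {'name': 'SVTRNet'},                          # assumed encoder name
    'Decoder': {'name': 'CTCDecoder', 'out_channels': 97},   # charset size incl. CTC blank
}

model = build_model(config)
model.eval()
with torch.no_grad():
    preds = model(torch.randn(1, 3, 32, 128))                # decoder output for one image
```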
diff --git a/openrec/modeling/common.py b/openrec/modeling/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..3330d26c670848c6ac93c6f9234d8473b82fbe30
--- /dev/null
+++ b/openrec/modeling/common.py
@@ -0,0 +1,238 @@
+import torch
+import torch.nn as nn
+
+
+class GELU(nn.Module):
+
+ def __init__(self, inplace=True):
+ super(GELU, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return torch.nn.functional.gelu(x)
+
+
+class Swish(nn.Module):
+
+ def __init__(self, inplace=True):
+ super(Swish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ if self.inplace:
+ x.mul_(torch.sigmoid(x))
+ return x
+ else:
+ return x * torch.sigmoid(x)
+
+
+class Activation(nn.Module):
+
+ def __init__(self, act_type, inplace=True):
+ super(Activation, self).__init__()
+ act_type = act_type.lower()
+ if act_type == 'relu':
+ self.act = nn.ReLU(inplace=inplace)
+ elif act_type == 'relu6':
+ self.act = nn.ReLU6(inplace=inplace)
+ elif act_type == 'sigmoid':
+ self.act = nn.Sigmoid()
+ elif act_type == 'hard_sigmoid':
+ self.act = nn.Hardsigmoid(inplace)
+ elif act_type == 'hard_swish':
+ self.act = nn.Hardswish(inplace=inplace)
+ elif act_type == 'leakyrelu':
+ self.act = nn.LeakyReLU(inplace=inplace)
+ elif act_type == 'gelu':
+ self.act = GELU(inplace=inplace)
+ elif act_type == 'swish':
+ self.act = Swish(inplace=inplace)
+ else:
+ raise NotImplementedError
+
+ def forward(self, inputs):
+ return self.act(inputs)
+
+
+def drop_path(x,
+ drop_prob: float = 0.0,
+ training: bool = False,
+ scale_by_keep: bool = True):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0], ) + (1, ) * (
+ x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks)."""
+
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
+
+class Identity(nn.Module):
+
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, input):
+ return input
+
+
+class Mlp(nn.Module):
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+
+ def __init__(self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = (self.qkv(x).reshape(B, N, 3, self.num_heads,
+ C // self.num_heads).permute(2, 0, 3, 1, 4))
+ q, k, v = qkv[0], qkv[1], qkv[
+ 2] # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding."""
+
+ def __init__(self,
+ img_size=[32, 128],
+ patch_size=[4, 4],
+ in_chans=3,
+ embed_dim=768):
+ super().__init__()
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] //
+ patch_size[0])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=patch_size)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert (
+ H == self.img_size[0] and W == self.img_size[1]
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
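These are stock ViT-style building blocks shared by several encoders and decoders in this patch: `Mlp`, a single-stream `Attention`, a pre-norm `Block`, and `DropPath` (stochastic depth), which is the identity at inference and randomly zeroes whole samples (rescaled by `1/keep_prob`) during training. A quick shape/behaviour check using only the classes defined above:

```python
# Sanity check for the shared blocks in openrec/modeling/common.py.
import torch

from openrec.modeling.common import Block, DropPath

blk = Block(dim=64, num_heads=4, mlp_ratio=4.0, drop_path=0.1)
x = torch.randn(2, 50, 64)            # (batch, tokens, dim)
assert blk(x).shape == x.shape        # attention + MLP preserve the token shape

dp = DropPath(drop_prob=0.5)
dp.eval()
assert torch.equal(dp(x), x)          # identity at inference
dp.train()
y = dp(x)                             # whole samples dropped, survivors scaled by 1/keep_prob
```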
diff --git a/openrec/modeling/decoders/__init__.py b/openrec/modeling/decoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f570fd002fa2b5886b1abdc16ebeaba2eac4a416
--- /dev/null
+++ b/openrec/modeling/decoders/__init__.py
@@ -0,0 +1,109 @@
+import torch.nn as nn
+
+__all__ = ['build_decoder']
+
+
+def build_decoder(config):
+ # rec decoder
+ from .abinet_decoder import ABINetDecoder
+ from .aster_decoder import ASTERDecoder
+ from .cdistnet_decoder import CDistNetDecoder
+ from .cppd_decoder import CPPDDecoder
+ from .rctc_decoder import RCTCDecoder
+ from .ctc_decoder import CTCDecoder
+ from .dan_decoder import DANDecoder
+ from .igtr_decoder import IGTRDecoder
+ from .lister_decoder import LISTERDecoder
+ from .lpv_decoder import LPVDecoder
+ from .mgp_decoder import MGPDecoder
+ from .nrtr_decoder import NRTRDecoder
+ from .parseq_decoder import PARSeqDecoder
+ from .robustscanner_decoder import RobustScannerDecoder
+ from .sar_decoder import SARDecoder
+ from .smtr_decoder import SMTRDecoder
+ from .smtr_decoder_nattn import SMTRDecoderNumAttn
+ from .srn_decoder import SRNDecoder
+ from .visionlan_decoder import VisionLANDecoder
+ from .matrn_decoder import MATRNDecoder
+ from .cam_decoder import CAMDecoder
+ from .ote_decoder import OTEDecoder
+ from .bus_decoder import BUSDecoder
+
+ support_dict = [
+ 'CTCDecoder', 'NRTRDecoder', 'CPPDDecoder', 'ABINetDecoder',
+ 'CDistNetDecoder', 'VisionLANDecoder', 'PARSeqDecoder', 'IGTRDecoder',
+ 'SMTRDecoder', 'LPVDecoder', 'SARDecoder', 'RobustScannerDecoder',
+ 'SRNDecoder', 'ASTERDecoder', 'RCTCDecoder', 'LISTERDecoder',
+ 'GTCDecoder', 'SMTRDecoderNumAttn', 'MATRNDecoder', 'MGPDecoder',
+ 'DANDecoder', 'CAMDecoder', 'OTEDecoder', 'BUSDecoder'
+ ]
+
+ module_name = config.pop('name')
+ assert module_name in support_dict, Exception(
+ 'decoder only support {}'.format(support_dict))
+ module_class = eval(module_name)(**config)
+ return module_class
+
+
+class GTCDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ gtc_decoder,
+ ctc_decoder,
+ detach=True,
+ infer_gtc=False,
+ out_channels=0,
+ **kwargs):
+ super(GTCDecoder, self).__init__()
+ self.detach = detach
+ self.infer_gtc = infer_gtc
+ if infer_gtc:
+ gtc_decoder['out_channels'] = out_channels[0]
+ ctc_decoder['out_channels'] = out_channels[1]
+ gtc_decoder['in_channels'] = in_channels
+ ctc_decoder['in_channels'] = in_channels
+ self.gtc_decoder = build_decoder(gtc_decoder)
+ else:
+ ctc_decoder['in_channels'] = in_channels
+ ctc_decoder['out_channels'] = out_channels
+ self.ctc_decoder = build_decoder(ctc_decoder)
+
+ def forward(self, x, data=None):
+ ctc_pred = self.ctc_decoder(x.detach() if self.detach else x,
+ data=data)
+ if self.training or self.infer_gtc:
+ gtc_pred = self.gtc_decoder(x.flatten(2).transpose(1, 2),
+ data=data)
+ return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}
+ else:
+ return ctc_pred
+
+
+class GTCDecoderTwo(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ gtc_decoder,
+ ctc_decoder,
+ infer_gtc=False,
+ out_channels=0,
+ **kwargs):
+ super(GTCDecoderTwo, self).__init__()
+ self.infer_gtc = infer_gtc
+ gtc_decoder['out_channels'] = out_channels[0]
+ ctc_decoder['out_channels'] = out_channels[1]
+ gtc_decoder['in_channels'] = in_channels
+ ctc_decoder['in_channels'] = in_channels
+ self.gtc_decoder = build_decoder(gtc_decoder)
+ self.ctc_decoder = build_decoder(ctc_decoder)
+
+ def forward(self, x, data=None):
+ x_ctc, x_gtc = x
+ ctc_pred = self.ctc_decoder(x_ctc, data=data)
+ if self.training or self.infer_gtc:
+ gtc_pred = self.gtc_decoder(x_gtc.flatten(2).transpose(1, 2),
+ data=data)
+ return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}
+ else:
+ return ctc_pred
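`build_decoder` dispatches on the `name` key and passes the remaining keys straight to the class constructor; `GTCDecoder`/`GTCDecoderTwo` reuse it to assemble a CTC branch plus an attention ("GTC") branch, returning both heads during training and only the CTC head at inference. A small sketch with the registered `CTCDecoder` (channel sizes are illustrative; in the full pipeline `in_channels` is filled in by `BaseRecognizer`):

```python
# Build a standalone CTC head through the registry above (illustrative sizes).
from openrec.modeling.decoders import build_decoder

ctc_head = build_decoder({
    'name': 'CTCDecoder',
    'in_channels': 256,    # normally set from the encoder's out_channels
    'out_channels': 97,    # charset size including the CTC blank
})
```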
diff --git a/openrec/modeling/decoders/abinet_decoder.py b/openrec/modeling/decoders/abinet_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..41480d91ad6a552455247b4c53427372f21e8717
--- /dev/null
+++ b/openrec/modeling/decoders/abinet_decoder.py
@@ -0,0 +1,283 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock
+
+
+class BCNLanguage(nn.Module):
+
+ def __init__(
+ self,
+ d_model=512,
+ nhead=8,
+ num_layers=4,
+ dim_feedforward=2048,
+ dropout=0.0,
+ max_length=25,
+ detach=True,
+ num_classes=37,
+ ):
+ super().__init__()
+ self.d_model = d_model
+ self.detach = detach
+ self.max_length = max_length + 1
+
+ self.proj = nn.Linear(num_classes, d_model, False)
+ self.token_encoder = PositionalEncoding(dropout=0.1,
+ dim=d_model,
+ max_len=self.max_length)
+ self.pos_encoder = PositionalEncoding(dropout=0,
+ dim=d_model,
+ max_len=self.max_length)
+ self.decoder = nn.ModuleList([
+ TransformerBlock(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ attention_dropout_rate=dropout,
+ residual_dropout_rate=dropout,
+ with_self_attn=False,
+ with_cross_attn=True,
+ ) for i in range(num_layers)
+ ])
+
+ self.cls = nn.Linear(d_model, num_classes)
+
+ def forward(self, tokens, lengths):
+ """
+ Args:
+            tokens: (N, T, C) where N is batch size, T is sequence length and C is the number of classes
+ lengths: (N,)
+ """
+ if self.detach:
+ tokens = tokens.detach()
+ embed = self.proj(tokens) # (N, T, E)
+ embed = self.token_encoder(embed) # (N, T, E)
+ mask = _get_mask(lengths, self.max_length) # (N, 1, T, T)
+ zeros = embed.new_zeros(*embed.shape)
+        query = self.pos_encoder(zeros)
+        for decoder_layer in self.decoder:
+            query = decoder_layer(query, embed, cross_mask=mask)
+        output = query  # (N, T, E)
+
+ logits = self.cls(output) # (N, T, C)
+ return output, logits
+
+
+def encoder_layer(in_c, out_c, k=3, s=2, p=1):
+ return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p),
+ nn.BatchNorm2d(out_c), nn.ReLU(True))
+
+
+class DecoderUpsample(nn.Module):
+
+ def __init__(self, in_c, out_c, k=3, s=1, p=1, mode='nearest') -> None:
+ super().__init__()
+ self.align_corners = None if mode == 'nearest' else True
+ self.mode = mode
+ # nn.Upsample(size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners),
+ self.w = nn.Sequential(
+ nn.Conv2d(in_c, out_c, k, s, p),
+ nn.BatchNorm2d(out_c),
+ nn.ReLU(True),
+ )
+
+ def forward(self, x, size):
+ x = F.interpolate(x,
+ size=size,
+ mode=self.mode,
+ align_corners=self.align_corners)
+ return self.w(x)
+
+
+class PositionAttention(nn.Module):
+
+ def __init__(self,
+ max_length,
+ in_channels=512,
+ num_channels=64,
+ mode='nearest',
+ **kwargs):
+ super().__init__()
+ self.max_length = max_length
+ self.k_encoder = nn.Sequential(
+ encoder_layer(in_channels, num_channels, s=(1, 2)),
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
+ )
+ self.k_decoder = nn.ModuleList([
+ DecoderUpsample(num_channels, num_channels, mode=mode),
+ DecoderUpsample(num_channels, num_channels, mode=mode),
+ DecoderUpsample(num_channels, num_channels, mode=mode),
+ DecoderUpsample(num_channels, in_channels, mode=mode),
+ ])
+
+ self.pos_encoder = PositionalEncoding(dropout=0,
+ dim=in_channels,
+ max_len=max_length)
+ self.project = nn.Linear(in_channels, in_channels)
+
+ def forward(self, x, query=None):
+ N, E, H, W = x.size()
+ k, v = x, x # (N, E, H, W)
+
+ # calculate key vector
+ features = []
+ size_decoder = []
+ for i in range(0, len(self.k_encoder)):
+ size_decoder.append(k.shape[2:])
+ k = self.k_encoder[i](k)
+ features.append(k)
+ for i in range(0, len(self.k_decoder) - 1):
+ k = self.k_decoder[i](k, size=size_decoder[-(i + 1)])
+ k = k + features[len(self.k_decoder) - 2 - i]
+ k = self.k_decoder[-1](k, size=size_decoder[0]) # (N, E, H, W)
+ # calculate query vector
+ # TODO q=f(q,k)
+ zeros = x.new_zeros(
+ (N, self.max_length, E)) if query is None else query # (N, T, E)
+ q = self.pos_encoder(zeros) # (N, T, E)
+ q = self.project(q) # (N, T, E)
+
+ # calculate attention
+ attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W))
+ attn_scores = attn_scores / (E**0.5)
+ attn_scores = F.softmax(attn_scores, dim=-1)
+
+ # (N, E, H, W) -> (N, H, W, E) -> (N, (H*W), E)
+ v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E)
+ attn_vecs = torch.bmm(attn_scores, v) # (N, T, E)
+ return attn_vecs, attn_scores.view(N, -1, H, W)
+
+
+class ABINetDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ nhead=8,
+ num_layers=3,
+ dim_feedforward=2048,
+ dropout=0.1,
+ max_length=25,
+ iter_size=3,
+ **kwargs):
+ super().__init__()
+ self.max_length = max_length + 1
+ d_model = in_channels
+ self.pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
+ self.encoder = nn.ModuleList([
+ TransformerBlock(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ attention_dropout_rate=dropout,
+ residual_dropout_rate=dropout,
+ with_self_attn=True,
+ with_cross_attn=False,
+ ) for _ in range(num_layers)
+ ])
+ self.decoder = PositionAttention(
+ max_length=self.max_length, # additional stop token
+ in_channels=d_model,
+ num_channels=d_model // 8,
+ mode='nearest',
+ )
+ self.out_channels = out_channels
+ self.cls = nn.Linear(d_model, self.out_channels)
+ self.iter_size = iter_size
+ if iter_size > 0:
+ self.language = BCNLanguage(
+ d_model=d_model,
+ nhead=nhead,
+ num_layers=4,
+ dim_feedforward=dim_feedforward,
+ dropout=dropout,
+ max_length=max_length,
+ num_classes=self.out_channels,
+ )
+ # alignment
+ self.w_att_align = nn.Linear(2 * d_model, d_model)
+ self.cls_align = nn.Linear(d_model, self.out_channels)
+
+ def forward(self, x, data=None):
+ # bs, c, h, w
+ x = x.permute([0, 2, 3, 1]) # bs, h, w, c
+ _, H, W, C = x.shape
+ # assert H % 8 == 0 and W % 16 == 0, 'The height and width should be multiples of 8 and 16.'
+ feature = x.flatten(1, 2) # bs, h*w, c
+ feature = self.pos_encoder(feature) # bs, h*w, c
+ for encoder_layer in self.encoder:
+ feature = encoder_layer(feature)
+ # bs, h*w, c
+ feature = feature.reshape([-1, H, W, C]).permute(0, 3, 1,
+ 2) # bs, c, h, w
+ v_feature, _ = self.decoder(feature) # (bs[N], T, E)
+        vis_logits = self.cls(v_feature)  # (N, T, num_classes)
+ align_lengths = _get_length(vis_logits)
+ align_logits = vis_logits
+ all_l_res, all_a_res = [], []
+ for _ in range(self.iter_size):
+ tokens = F.softmax(align_logits, dim=-1)
+ lengths = torch.clamp(
+ align_lengths, 2,
+ self.max_length) # TODO: move to language model
+ l_feature, l_logits = self.language(tokens, lengths)
+
+ # alignment
+ all_l_res.append(l_logits)
+ fuse = torch.cat((l_feature, v_feature), -1)
+ f_att = torch.sigmoid(self.w_att_align(fuse))
+ output = f_att * v_feature + (1 - f_att) * l_feature
+ align_logits = self.cls_align(output)
+
+ align_lengths = _get_length(align_logits)
+ all_a_res.append(align_logits)
+ if self.training:
+ return {
+ 'align': all_a_res,
+ 'lang': all_l_res,
+ 'vision': vis_logits
+ }
+ else:
+ return F.softmax(align_logits, -1)
+
+
+def _get_length(logit):
+ """Greed decoder to obtain length from logit."""
+ out = logit.argmax(dim=-1) == 0
+ non_zero_mask = out.int() != 0
+ mask_max_values, mask_max_indices = torch.max(non_zero_mask.int(), dim=-1)
+ mask_max_indices[mask_max_values == 0] = -1
+ out = mask_max_indices + 1
+ return out
+
+
+def _get_mask(length, max_length):
+ """Generate a square mask for the sequence.
+
+ The masked positions are filled with float('-inf'). Unmasked positions are
+ filled with float(0.0).
+ """
+ length = length.unsqueeze(-1)
+ N = length.size(0)
+ grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
+ zero_mask = torch.zeros([N, max_length],
+ dtype=torch.float32,
+ device=length.device)
+ inf_mask = torch.full([N, max_length],
+ float('-inf'),
+ dtype=torch.float32,
+ device=length.device)
+ diag_mask = torch.diag(
+ torch.full([max_length],
+ float('-inf'),
+ dtype=torch.float32,
+ device=length.device),
+ diagonal=0,
+ )
+ mask = torch.where(grid >= length, inf_mask, zero_mask)
+ mask = mask.unsqueeze(1) + diag_mask
+ return mask.unsqueeze(1)
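In `ABINetDecoder`, class index 0 is the end-of-sequence symbol: `_get_length` returns the index of the first predicted EOS plus one (or 0 if EOS never appears), and that length gates the BCN language model's cross-attention mask on each refinement iteration. A small check of that behaviour:

```python
# _get_length: position of the first argmax==0 (EOS), plus one; 0 if EOS is absent.
import torch

from openrec.modeling.decoders.abinet_decoder import _get_length

logits = torch.full((1, 5, 3), -10.0)
logits[0, 0, 2] = 10.0   # char
logits[0, 1, 1] = 10.0   # char
logits[0, 2, 0] = 10.0   # first EOS at position 2
logits[0, 3, 1] = 10.0   # char after EOS
logits[0, 4, 0] = 10.0   # later EOS is ignored
print(_get_length(logits))  # tensor([3]) -> two characters + the EOS slot
```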
diff --git a/openrec/modeling/decoders/aster_decoder.py b/openrec/modeling/decoders/aster_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..207910b655116769be8efada600e09641d47fbda
--- /dev/null
+++ b/openrec/modeling/decoders/aster_decoder.py
@@ -0,0 +1,170 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.nn import init
+
+
+class Embedding(nn.Module):
+
+ def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
+ super(Embedding, self).__init__()
+ self.in_timestep = in_timestep
+ self.in_planes = in_planes
+ self.embed_dim = embed_dim
+ self.mid_dim = mid_dim
+ self.eEmbed = nn.Linear(
+ in_timestep * in_planes,
+            self.embed_dim)  # embed the flattened encoder output into a word-embedding-like vector
+
+ def forward(self, x):
+ x = x.flatten(1)
+ x = self.eEmbed(x)
+ return x
+
+
+class Attn_Rnn_Block(nn.Module):
+
+ def __init__(self, featdim, hiddendim, embedding_dim, out_channels,
+ attndim):
+ super(Attn_Rnn_Block, self).__init__()
+
+ self.attndim = attndim
+ self.embedding_dim = embedding_dim
+ self.feat_embed = nn.Linear(featdim, attndim)
+ self.hidden_embed = nn.Linear(hiddendim, attndim)
+ self.attnfeat_embed = nn.Linear(attndim, 1)
+ self.gru = nn.GRU(input_size=featdim + self.embedding_dim,
+ hidden_size=hiddendim,
+ batch_first=True)
+ self.fc = nn.Linear(hiddendim, out_channels)
+ self.init_weights()
+
+ def init_weights(self):
+ init.normal_(self.hidden_embed.weight, std=0.01)
+ init.constant_(self.hidden_embed.bias, 0)
+ init.normal_(self.attnfeat_embed.weight, std=0.01)
+ init.constant_(self.attnfeat_embed.bias, 0)
+
+ def _attn(self, feat, h_state):
+ b, t, _ = feat.shape
+ feat = self.feat_embed(feat)
+ h_state = self.hidden_embed(h_state.squeeze(0)).unsqueeze(1)
+ h_state = h_state.expand(b, t, self.attndim)
+ sumTanh = torch.tanh(feat + h_state)
+ attn_w = self.attnfeat_embed(sumTanh).squeeze(-1)
+ attn_w = F.softmax(attn_w, dim=1).unsqueeze(1)
+ # [B,1,25]
+ return attn_w
+
+ def forward(self, feat, h_state, label_input):
+
+ attn_w = self._attn(feat, h_state)
+
+ attn_feat = attn_w @ feat
+ attn_feat = attn_feat.squeeze(1)
+
+ output, h_state = self.gru(
+ torch.cat([label_input, attn_feat], 1).unsqueeze(1), h_state)
+ pred = self.fc(output)
+
+ return pred, h_state
+
+
+class ASTERDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ embedding_dim=256,
+ hiddendim=256,
+ attndim=256,
+ max_len=25,
+ seed=False,
+ time_step=32,
+ **kwargs):
+ super(ASTERDecoder, self).__init__()
+ self.num_classes = out_channels
+ self.bos = out_channels - 2
+ self.eos = 0
+ self.padding_idx = out_channels - 1
+ self.seed = seed
+ if seed:
+ self.embeder = Embedding(
+ in_timestep=time_step,
+ in_planes=in_channels,
+ )
+ self.word_embedding = nn.Embedding(self.num_classes,
+ embedding_dim,
+ padding_idx=self.padding_idx)
+
+ self.attndim = attndim
+ self.hiddendim = hiddendim
+ self.max_seq_len = max_len + 1
+
+ self.featdim = in_channels
+
+ self.attn_rnn_block = Attn_Rnn_Block(
+ featdim=self.featdim,
+ hiddendim=hiddendim,
+ embedding_dim=embedding_dim,
+ out_channels=out_channels - 2,
+ attndim=attndim,
+ )
+ self.embed_fc = nn.Linear(300, self.hiddendim)
+
+ def get_initial_state(self, embed, tile_times=1):
+ assert embed.shape[1] == 300
+ state = self.embed_fc(embed) # N * sDim
+ if tile_times != 1:
+ state = state.unsqueeze(1)
+ trans_state = state.transpose(0, 1)
+ state = trans_state.tile([tile_times, 1, 1])
+ trans_state = state.transpose(0, 1)
+ state = trans_state.reshape(-1, self.hiddendim)
+ state = state.unsqueeze(0) # 1 * N * sDim
+ return state
+
+ def forward(self, feat, data=None):
+ # b,25,512
+ b = feat.size(0)
+ if self.seed:
+ embedding_vectors = self.embeder(feat)
+ h_state = self.get_initial_state(embedding_vectors)
+ else:
+ h_state = torch.zeros(1, b, self.hiddendim).to(feat.device)
+ outputs = []
+ if self.training:
+ label = data[0]
+ label_embedding = self.word_embedding(label) # [B,25,256]
+ tokens = label_embedding[:, 0, :]
+ max_len = data[1].max() + 1
+ else:
+ tokens = torch.full([b, 1],
+ self.bos,
+ device=feat.device,
+ dtype=torch.long)
+ tokens = self.word_embedding(tokens.squeeze(1))
+ max_len = self.max_seq_len
+ pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
+ outputs.append(pred)
+
+ dec_seq = torch.full((feat.shape[0], max_len),
+ self.padding_idx,
+ dtype=torch.int64,
+                             device=feat.device)
+ dec_seq[:, :1] = torch.argmax(pred, dim=-1)
+ for i in range(1, max_len):
+ if not self.training:
+ max_idx = torch.argmax(pred, dim=-1).squeeze(1)
+ tokens = self.word_embedding(max_idx)
+ dec_seq[:, i] = max_idx
+ if (dec_seq == self.eos).any(dim=-1).all():
+ break
+ else:
+ tokens = label_embedding[:, i, :]
+ pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
+ outputs.append(pred)
+ preds = torch.cat(outputs, 1)
+ if self.seed and self.training:
+ return [embedding_vectors, preds]
+ return preds if self.training else F.softmax(preds, -1)
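`ASTERDecoder` is a GRU attention decoder: at each step it attends over the encoder's time steps, concatenates the attended feature with the previous character embedding, and feeds the GRU; training is teacher-forced on `data`, while inference decodes greedily and stops early once every sequence in the batch has emitted EOS (index 0). An inference-shape sketch with illustrative sizes (`out_channels` counts the charset plus the <bos> and <pad> symbols):

```python
# Greedy-inference shape sketch for ASTERDecoder (illustrative sizes).
import torch

from openrec.modeling.decoders.aster_decoder import ASTERDecoder

decoder = ASTERDecoder(in_channels=512, out_channels=39)  # 37 classes + <bos> + <pad>
decoder.eval()

feat = torch.randn(2, 32, 512)        # (batch, time steps, channels) from the encoder
with torch.no_grad():
    probs = decoder(feat)             # (batch, <=26 steps, out_channels - 2), softmax probabilities
```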
diff --git a/openrec/modeling/decoders/bus_decoder.py b/openrec/modeling/decoders/bus_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bf6f7feceedd4e7a32691fca1c69adf2bd38d59
--- /dev/null
+++ b/openrec/modeling/decoders/bus_decoder.py
@@ -0,0 +1,133 @@
+"""This code is refer from:
+https://github.com/jjwei66/BUSNet
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .nrtr_decoder import PositionalEncoding, TransformerBlock
+from .abinet_decoder import _get_mask, _get_length
+
+
+class BUSDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ nhead=8,
+ num_layers=4,
+ dim_feedforward=2048,
+ dropout=0.1,
+ max_length=25,
+ ignore_index=100,
+ pretraining=False,
+ detach=True):
+ super().__init__()
+ d_model = in_channels
+ self.ignore_index = ignore_index
+ self.pretraining = pretraining
+ self.d_model = d_model
+ self.detach = detach
+ self.max_length = max_length + 1 # additional stop token
+ self.out_channels = out_channels
+ # --------------------------------------------------------------------------
+ # decoder specifics
+ self.proj = nn.Linear(out_channels, d_model, False)
+ self.token_encoder = PositionalEncoding(dropout=0.1,
+ dim=d_model,
+ max_len=self.max_length)
+ self.pos_encoder = PositionalEncoding(dropout=0.1,
+ dim=d_model,
+ max_len=self.max_length)
+
+ self.decoder = nn.ModuleList([
+ TransformerBlock(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ attention_dropout_rate=dropout,
+ residual_dropout_rate=dropout,
+ with_self_attn=False,
+ with_cross_attn=True,
+ ) for i in range(num_layers)
+ ])
+
+ v_mask = torch.empty((1, 1, d_model))
+ l_mask = torch.empty((1, 1, d_model))
+ self.v_mask = nn.Parameter(v_mask)
+ self.l_mask = nn.Parameter(l_mask)
+ torch.nn.init.uniform_(self.v_mask, -0.001, 0.001)
+ torch.nn.init.uniform_(self.l_mask, -0.001, 0.001)
+
+ v_embeding = torch.empty((1, 1, d_model))
+ l_embeding = torch.empty((1, 1, d_model))
+ self.v_embeding = nn.Parameter(v_embeding)
+ self.l_embeding = nn.Parameter(l_embeding)
+ torch.nn.init.uniform_(self.v_embeding, -0.001, 0.001)
+ torch.nn.init.uniform_(self.l_embeding, -0.001, 0.001)
+ self.cls = nn.Linear(d_model, out_channels)
+
+ def forward_decoder(self, q, x, mask=None):
+ for decoder_layer in self.decoder:
+ q = decoder_layer(q, x, cross_mask=mask)
+ output = q # (N, T, E)
+ logits = self.cls(output) # (N, T, C)
+ return logits
+
+ def forward(self, img_feat, data=None):
+ """
+ Args:
+            img_feat: (N, L, E) visual features from the encoder
+            data: optional ground-truth tensors used when pretraining the language branch
+ """
+ img_feat = img_feat + self.v_embeding
+ B, L, C = img_feat.shape
+
+ # --------------------------------------------------------------------------
+ # decoder procedure
+ T = self.max_length
+ zeros = img_feat.new_zeros((B, T, C))
+ zeros_len = img_feat.new_zeros(B)
+ query = self.pos_encoder(zeros)
+
+ # 1. vision decode
+ v_embed = torch.cat((img_feat, self.l_mask.repeat(B, T, 1)),
+ dim=1) # v
+ padding_mask = _get_mask(
+ self.max_length + zeros_len,
+            self.max_length)  # mask positions beyond the token length  # B, maxlen, maxlen
+ v_mask = torch.zeros((1, 1, self.max_length, L),
+ device=img_feat.device).tile([B, 1, 1,
+ 1]) # maxlen L
+ mask = torch.cat((v_mask, padding_mask), 3)
+ v_logits = self.forward_decoder(query, v_embed, mask=mask)
+
+ # 2. language decode
+ if self.training and self.pretraining:
+ tgt = torch.where(data[0] == self.ignore_index, 0, data[0])
+ tokens = F.one_hot(tgt, num_classes=self.out_channels)
+ tokens = tokens.float()
+ lengths = data[-1]
+ else:
+ tokens = torch.softmax(v_logits, dim=-1)
+ lengths = _get_length(v_logits)
+ tokens = tokens.detach()
+ token_embed = self.proj(tokens) # (N, T, E)
+ token_embed = self.token_encoder(token_embed) # (T, N, E)
+ token_embed = token_embed + self.l_embeding
+
+ padding_mask = _get_mask(lengths,
+                                 self.max_length)  # mask positions beyond the token length
+ mask = torch.cat((v_mask, padding_mask), 3)
+ l_embed = torch.cat((self.v_mask.repeat(B, L, 1), token_embed), dim=1)
+ l_logits = self.forward_decoder(query, l_embed, mask=mask)
+
+ # 3. vision language decode
+ vl_embed = torch.cat((img_feat, token_embed), dim=1)
+ vl_logits = self.forward_decoder(query, vl_embed, mask=mask)
+
+ if self.training:
+ return {'align': [vl_logits], 'lang': l_logits, 'vision': v_logits}
+ else:
+ return F.softmax(vl_logits, -1)
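`BUSDecoder` runs the same transformer decoder stack three times over one positional query: a vision pass over the visual features concatenated with a learned placeholder for the language tokens, a language pass over the predicted-token embeddings concatenated with a learned placeholder for the visual features, and a fused vision+language pass whose logits are the final prediction; only the fused pass is returned at inference, while training also supervises the first two. A shape sketch, assuming the `PositionalEncoding`/`TransformerBlock` modules from `nrtr_decoder` behave as they are used above:

```python
# Inference shape sketch for BUSDecoder (illustrative sizes).
import torch

from openrec.modeling.decoders.bus_decoder import BUSDecoder

decoder = BUSDecoder(in_channels=512, out_channels=37, max_length=25)
decoder.eval()

img_feat = torch.randn(2, 128, 512)   # (batch, visual tokens, dim)
with torch.no_grad():
    probs = decoder(img_feat)         # (batch, 26, 37) softmaxed vision+language logits
```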
diff --git a/openrec/modeling/decoders/cam_decoder.py b/openrec/modeling/decoders/cam_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cb47345076e6a537c5f5bd1258c0e92ab7ffdc7
--- /dev/null
+++ b/openrec/modeling/decoders/cam_decoder.py
@@ -0,0 +1,43 @@
+import torch.nn as nn
+
+from .nrtr_decoder import NRTRDecoder
+
+
+class CAMDecoder(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ nhead=None,
+ num_encoder_layers=6,
+ beam_size=0,
+ num_decoder_layers=6,
+ max_len=25,
+ attention_dropout_rate=0.0,
+ residual_dropout_rate=0.1,
+ scale_embedding=True,
+ ):
+ super().__init__()
+
+ self.decoder = NRTRDecoder(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ nhead=nhead,
+ num_encoder_layers=num_encoder_layers,
+ beam_size=beam_size,
+ num_decoder_layers=num_decoder_layers,
+ max_len=max_len,
+ attention_dropout_rate=attention_dropout_rate,
+ residual_dropout_rate=residual_dropout_rate,
+ scale_embedding=scale_embedding,
+ )
+
+ def forward(self, x, data=None):
+ dec_in = x['refined_feat']
+ dec_output = self.decoder(dec_in, data=data)
+ x['rec_output'] = dec_output
+ if self.training:
+ return x
+ else:
+ return dec_output
diff --git a/openrec/modeling/decoders/cdistnet_decoder.py b/openrec/modeling/decoders/cdistnet_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5604b787b6ca6c271267906608a17dfcf488c0a
--- /dev/null
+++ b/openrec/modeling/decoders/cdistnet_decoder.py
@@ -0,0 +1,334 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from openrec.modeling.decoders.nrtr_decoder import Embeddings, PositionalEncoding, TransformerBlock # , Beam
+from openrec.modeling.decoders.visionlan_decoder import Transformer_Encoder
+
+
+def generate_square_subsequent_mask(sz):
+ r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
+ Unmasked positions are filled with float(0.0).
+ """
+ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+ mask = (mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
+ mask == 1, float(0.0)))
+ return mask
+
+
+class SEM_Pre(nn.Module):
+
+ def __init__(
+ self,
+ d_model=512,
+ dst_vocab_size=40,
+ residual_dropout_rate=0.1,
+ ):
+ super(SEM_Pre, self).__init__()
+ self.embedding = Embeddings(d_model=d_model, vocab=dst_vocab_size)
+
+ self.positional_encoding = PositionalEncoding(
+ dropout=residual_dropout_rate,
+ dim=d_model,
+ )
+
+ def forward(self, tgt):
+ tgt = self.embedding(tgt)
+ tgt = self.positional_encoding(tgt)
+ tgt_mask = generate_square_subsequent_mask(tgt.shape[1]).to(tgt.device)
+ return tgt, tgt_mask
+
+
+class POS_Pre(nn.Module):
+
+ def __init__(
+ self,
+ d_model=512,
+ ):
+ super(POS_Pre, self).__init__()
+ self.pos_encoding = PositionalEncoding(
+ dropout=0.1,
+ dim=d_model,
+ )
+ self.linear1 = nn.Linear(d_model, d_model)
+ self.linear2 = nn.Linear(d_model, d_model)
+
+ self.norm2 = nn.LayerNorm(d_model)
+
+ def forward(self, tgt):
+ pos = tgt.new_zeros(*tgt.shape)
+ pos = self.pos_encoding(pos)
+
+ pos2 = self.linear2(F.relu(self.linear1(pos)))
+ pos = self.norm2(pos + pos2)
+ return pos
+
+
+class DSF(nn.Module):
+
+ def __init__(self, d_model, fusion_num):
+ super(DSF, self).__init__()
+ self.w_att = nn.Linear(fusion_num * d_model, d_model)
+
+ def forward(self, l_feature, v_feature):
+ """
+ Args:
+            l_feature: (N, T, E) where N is batch size, T is sequence length and E is the model dim
+            v_feature: (N, T, E), same shape as l_feature
+        Returns:
+            (N, T, E): gated fusion of the two feature streams
+ """
+ f = torch.cat((l_feature, v_feature), dim=2)
+ f_att = torch.sigmoid(self.w_att(f))
+ output = f_att * v_feature + (1 - f_att) * l_feature
+
+ return output
+
+
+class MDCDP(nn.Module):
+ r"""
+    Multi-Domain Character Distance Perception (MDCDP).
+ """
+
+ def __init__(self, d_model, n_head, d_inner, num_layers):
+ super(MDCDP, self).__init__()
+
+ self.num_layers = num_layers
+
+ # step 1 SAE
+ self.layers_pos = nn.ModuleList([
+ TransformerBlock(d_model, n_head, d_inner)
+ for _ in range(num_layers)
+ ])
+
+ # step 2 CBI:
+ self.layers2 = nn.ModuleList([
+ TransformerBlock(
+ d_model,
+ n_head,
+ d_inner,
+ with_self_attn=False,
+ with_cross_attn=True,
+ ) for _ in range(num_layers)
+ ])
+ self.layers3 = nn.ModuleList([
+ TransformerBlock(
+ d_model,
+ n_head,
+ d_inner,
+ with_self_attn=False,
+ with_cross_attn=True,
+ ) for _ in range(num_layers)
+ ])
+
+ # step 3 :DSF
+ self.dynamic_shared_fusion = DSF(d_model, 2)
+
+ def forward(
+ self,
+ sem,
+ vis,
+ pos,
+ tgt_mask=None,
+ memory_mask=None,
+ ):
+
+ for i in range(self.num_layers):
+ # ----------step 1 -----------: SAE: Self-Attention Enhancement
+ pos = self.layers_pos[i](pos, self_mask=tgt_mask)
+
+ # ----------step 2 -----------: CBI: Cross-Branch Interaction
+
+ # CBI-V
+ pos_vis = self.layers2[i](
+ pos,
+ vis,
+ cross_mask=memory_mask,
+ )
+
+ # CBI-S
+ pos_sem = self.layers3[i](
+ pos,
+ sem,
+ cross_mask=tgt_mask,
+ )
+
+ # ----------step 3 -----------: DSF: Dynamic Shared Fusion
+ pos = self.dynamic_shared_fusion(pos_vis, pos_sem)
+
+ output = pos
+ return output
+
+
+class ConvBnRelu(nn.Module):
+ # adapt padding for kernel_size change
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ conv=nn.Conv2d,
+ stride=2,
+ inplace=True,
+ ):
+ super().__init__()
+ p_size = [int(k // 2) for k in kernel_size]
+ # p_size = int(kernel_size//2)
+ self.conv = conv(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=p_size,
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.relu = nn.ReLU(inplace=inplace)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ return x
+
+
+class CDistNetDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ n_head=None,
+ num_encoder_blocks=3,
+ num_decoder_blocks=3,
+ beam_size=0,
+ max_len=25,
+ residual_dropout_rate=0.1,
+ add_conv=False,
+ **kwargs):
+ super(CDistNetDecoder, self).__init__()
+ dst_vocab_size = out_channels
+ self.ignore_index = dst_vocab_size - 1
+ self.bos = dst_vocab_size - 2
+ self.eos = 0
+ self.beam_size = beam_size
+ self.max_len = max_len
+ self.add_conv = add_conv
+ d_model = in_channels
+ dim_feedforward = d_model * 4
+ n_head = n_head if n_head is not None else d_model // 32
+
+ if add_conv:
+ self.convbnrelu = ConvBnRelu(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=(1, 3),
+ stride=(1, 2),
+ )
+ if num_encoder_blocks > 0:
+ self.positional_encoding = PositionalEncoding(
+ dropout=0.1,
+ dim=d_model,
+ )
+ self.trans_encoder = Transformer_Encoder(
+ n_layers=num_encoder_blocks,
+ n_head=n_head,
+ d_model=d_model,
+ d_inner=dim_feedforward,
+ )
+ else:
+ self.trans_encoder = None
+ self.semantic_branch = SEM_Pre(
+ d_model=d_model,
+ dst_vocab_size=dst_vocab_size,
+ residual_dropout_rate=residual_dropout_rate,
+ )
+ self.positional_branch = POS_Pre(d_model=d_model)
+
+ self.mdcdp = MDCDP(d_model, n_head, dim_feedforward // 2,
+ num_decoder_blocks)
+ self._reset_parameters()
+
+ self.tgt_word_prj = nn.Linear(
+ d_model, dst_vocab_size - 2,
+            bias=False)  # we do not predict <bos> nor <pad>
+ self.tgt_word_prj.weight.data.normal_(mean=0.0, std=d_model**-0.5)
+
+ def forward(self, x, data=None):
+ if self.add_conv:
+ x = self.convbnrelu(x)
+ # x = rearrange(x, "b c h w -> b (w h) c")
+ x = x.flatten(2).transpose(1, 2)
+ if self.trans_encoder is not None:
+ x = self.positional_encoding(x)
+ vis_feat = self.trans_encoder(x, src_mask=None)
+ else:
+ vis_feat = x
+ if self.training:
+ max_len = data[1].max()
+ tgt = data[0][:, :1 + max_len]
+ res = self.forward_train(vis_feat, tgt)
+ else:
+ if self.beam_size > 0:
+ res = self.forward_beam(vis_feat)
+ else:
+ res = self.forward_test(vis_feat)
+ return res
+
+ def forward_train(self, vis_feat, tgt):
+ sem_feat, sem_mask = self.semantic_branch(tgt)
+ pos_feat = self.positional_branch(sem_feat)
+ output = self.mdcdp(
+ sem_feat,
+ vis_feat,
+ pos_feat,
+ tgt_mask=sem_mask,
+ memory_mask=None,
+ )
+
+ logit = self.tgt_word_prj(output)
+ return logit
+
+ def forward_test(self, vis_feat):
+ bs = vis_feat.size(0)
+
+ dec_seq = torch.full(
+ (bs, self.max_len + 1),
+ self.ignore_index,
+ dtype=torch.int64,
+ device=vis_feat.device,
+ )
+ dec_seq[:, 0] = self.bos
+ logits = []
+ for len_dec_seq in range(0, self.max_len):
+ sem_feat, sem_mask = self.semantic_branch(dec_seq[:, :len_dec_seq +
+ 1])
+ pos_feat = self.positional_branch(sem_feat)
+ output = self.mdcdp(
+ sem_feat,
+ vis_feat,
+ pos_feat,
+ tgt_mask=sem_mask,
+ memory_mask=None,
+ )
+
+ dec_output = output[:, -1:, :]
+
+ word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
+ logits.append(word_prob)
+ if len_dec_seq < self.max_len:
+ # greedy decode. add the next token index to the target input
+ dec_seq[:, len_dec_seq + 1] = word_prob.squeeze(1).argmax(-1)
+ # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
+ if (dec_seq == self.eos).any(dim=-1).all():
+ break
+ logits = torch.cat(logits, dim=1)
+ return logits
+
+ def forward_beam(self, x):
+ """Translation work in one batch."""
+ # to do
+
+ def _reset_parameters(self):
+ r"""Initiate parameters in the transformer model."""
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
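`CDistNetDecoder` decodes autoregressively: the semantic branch embeds the partial character sequence under the causal mask from `generate_square_subsequent_mask`, the positional branch supplies a learned position stream, and `MDCDP` fuses both with the visual features; `forward_test` then appends one greedily chosen character per step until every sequence has produced EOS. The mask itself is the standard additive causal mask:

```python
# The causal mask fed to the semantic branch: 0 on/below the diagonal, -inf above.
from openrec.modeling.decoders.cdistnet_decoder import generate_square_subsequent_mask

print(generate_square_subsequent_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
```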
diff --git a/openrec/modeling/decoders/cppd_decoder.py b/openrec/modeling/decoders/cppd_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb21d6fbf84be0067deef37a1f0864e809710fe7
--- /dev/null
+++ b/openrec/modeling/decoders/cppd_decoder.py
@@ -0,0 +1,393 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.init import ones_, trunc_normal_, zeros_
+
+from openrec.modeling.common import DropPath, Identity, Mlp
+from openrec.modeling.decoders.nrtr_decoder import Embeddings
+
+
+class Attention(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, q, kv, key_mask=None):
+ N, C = kv.shape[1:]
+ QN = q.shape[1]
+ q = self.q(q).reshape([-1, QN, self.num_heads,
+ C // self.num_heads]).transpose(1, 2)
+ q = q * self.scale
+ k, v = self.kv(kv).reshape(
+ [-1, N, 2, self.num_heads,
+ C // self.num_heads]).permute(2, 0, 3, 1, 4)
+
+ attn = q.matmul(k.transpose(2, 3))
+
+ if key_mask is not None:
+ attn = attn + key_mask.unsqueeze(1)
+
+ attn = F.softmax(attn, -1)
+ # if not self.training:
+ # self.attn_map = attn
+ attn = self.attn_drop(attn)
+
+ x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class EdgeDecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=[0.0, 0.0],
+ act_layer=nn.GELU,
+ norm_layer='nn.LayerNorm',
+ epsilon=1e-6,
+ ):
+ super().__init__()
+
+ self.head_dim = dim // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path1 = DropPath(
+ drop_path[0]) if drop_path[0] > 0.0 else Identity()
+        self.norm1 = eval(norm_layer)(dim, eps=epsilon)  # nn.LayerNorm expects 'eps', not 'epsilon'
+        self.norm2 = eval(norm_layer)(dim, eps=epsilon)
+
+ self.p = nn.Linear(dim, dim)
+ self.cv = nn.Linear(dim, dim)
+ self.pv = nn.Linear(dim, dim)
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.p_proj = nn.Linear(dim, dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp_ratio = mlp_ratio
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, p, cv, pv):
+ pN = p.shape[1]
+ vN = cv.shape[1]
+ p_shortcut = p
+
+ p1 = self.p(p).reshape(
+ [-1, pN, self.num_heads,
+ self.dim // self.num_heads]).transpose(1, 2)
+ cv1 = self.cv(cv).reshape(
+ [-1, vN, self.num_heads,
+ self.dim // self.num_heads]).transpose(1, 2)
+ pv1 = self.pv(pv).reshape(
+ [-1, vN, self.num_heads,
+ self.dim // self.num_heads]).transpose(1, 2)
+
+ edge = F.softmax(p1.matmul(pv1.transpose(2, 3)), -1) # B h N N
+
+ p_c = (edge @ cv1).transpose(1, 2).reshape((-1, pN, self.dim))
+
+ x1 = self.norm1(p_shortcut + self.drop_path1(self.p_proj(p_c)))
+
+ x = self.norm2(x1 + self.drop_path1(self.mlp(x1)))
+ return x
+
+
+class DecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ epsilon=1e-6,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim, eps=epsilon)
+ self.mixer = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ self.norm2 = norm_layer(dim, eps=epsilon)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp_ratio = mlp_ratio
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, q, kv, key_mask=None):
+ x1 = self.norm1(q + self.drop_path(self.mixer(q, kv, key_mask)))
+ x = self.norm2(x1 + self.drop_path(self.mlp(x1)))
+ return x
+
+
+class CPPDDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_layer=2,
+ drop_path_rate=0.1,
+ max_len=25,
+ vis_seq=50,
+ iters=1,
+ pos_len=False,
+ ch=False,
+ rec_layer=1,
+ num_heads=None,
+ ds=False,
+ **kwargs):
+ super(CPPDDecoder, self).__init__()
+
+ self.out_channels = out_channels # none + 26 + 10
+ dim = in_channels
+ self.dim = dim
+ self.iters = iters
+ self.max_len = max_len + 1 # max_len + eos
+ self.pos_len = pos_len
+ self.ch = ch
+ self.char_node_embed = Embeddings(d_model=dim,
+ vocab=self.out_channels,
+ scale_embedding=True)
+ self.pos_node_embed = Embeddings(d_model=dim,
+ vocab=self.max_len,
+ scale_embedding=True)
+ dpr = np.linspace(0, drop_path_rate, num_layer + rec_layer)
+
+ self.char_node_decoder = nn.ModuleList([
+ DecoderLayer(
+ dim=dim,
+ num_heads=dim // 32 if num_heads is None else num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=dpr[i],
+ ) for i in range(num_layer)
+ ])
+ self.pos_node_decoder = nn.ModuleList([
+ DecoderLayer(
+ dim=dim,
+ num_heads=dim // 32 if num_heads is None else num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=dpr[i],
+ ) for i in range(num_layer)
+ ])
+
+ self.edge_decoder = nn.ModuleList([
+ DecoderLayer(
+ dim=dim,
+ num_heads=dim // 32 if num_heads is None else num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=1.0 if (rec_layer + i) % 2 != 0 else None,
+ drop_path=dpr[num_layer + i],
+ ) for i in range(rec_layer)
+ ])
+ self.rec_layer_num = rec_layer
+ self_mask = torch.tril(
+ torch.ones([self.max_len, self.max_len], dtype=torch.float32))
+ self_mask = torch.where(
+ self_mask > 0,
+ torch.zeros_like(self_mask, dtype=torch.float32),
+ torch.full([self.max_len, self.max_len],
+ float('-inf'),
+ dtype=torch.float32),
+ )
+ self.self_mask = self_mask.unsqueeze(0)
+ self.char_pos_embed = nn.Parameter(torch.zeros([1, self.max_len, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+ self.ds = ds
+ if not self.ds:
+ self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+ trunc_normal_(self.vis_pos_embed, std=0.02)
+ self.char_node_fc1 = nn.Linear(dim, max_len)
+
+ self.pos_node_fc1 = nn.Linear(dim, self.max_len)
+
+ self.edge_fc = nn.Linear(dim, self.out_channels)
+
+ trunc_normal_(self.char_pos_embed, std=0.02)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {
+ 'char_pos_embed', 'vis_pos_embed', 'char_node_embed',
+ 'pos_node_embed'
+ }
+
+ def forward(self, x, data=None):
+ if self.training:
+ return self.forward_train(x, data)
+ else:
+ return self.forward_test(x)
+
+ def forward_test(self, x):
+ if not self.ds:
+ visual_feats = x + self.vis_pos_embed
+ else:
+ visual_feats = x
+ bs = visual_feats.shape[0]
+
+ pos_node_embed = self.pos_node_embed(
+ torch.arange(self.max_len).cuda(
+ x.get_device())).unsqueeze(0) + self.char_pos_embed
+ pos_node_embed = torch.tile(pos_node_embed, [bs, 1, 1])
+
+ char_vis_node_query = visual_feats
+ pos_vis_node_query = torch.concat([pos_node_embed, visual_feats], 1)
+
+ for char_decoder_layer, pos_decoder_layer in zip(
+ self.char_node_decoder, self.pos_node_decoder):
+ char_vis_node_query = char_decoder_layer(char_vis_node_query,
+ char_vis_node_query)
+ pos_vis_node_query = pos_decoder_layer(
+ pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :])
+
+ pos_node_query = pos_vis_node_query[:, :self.max_len, :]
+
+ char_vis_feats = char_vis_node_query
+ # pos_vis_feats = pos_vis_node_query[:, self.max_len :, :]
+
+ # pos_node_feats = self.edge_decoder(
+ # pos_node_query, char_vis_feats, pos_vis_feats
+ # ) # B, 26, dim
+
+ pos_node_feats = pos_node_query
+ for layer_i in range(self.rec_layer_num):
+ rec_layer = self.edge_decoder[layer_i]
+ if (self.rec_layer_num + layer_i) % 2 == 0:
+ pos_node_feats = rec_layer(pos_node_feats, pos_node_feats,
+ self.self_mask)
+ else:
+ pos_node_feats = rec_layer(pos_node_feats, char_vis_feats)
+ edge_feats = self.edge_fc(pos_node_feats) # B, 26, 37
+
+ edge_logits = F.softmax(
+ edge_feats,
+ -1) # * F.sigmoid(pos_node_feats1.unsqueeze(-1)) # B, 26, 37
+
+ return edge_logits
+
+ def forward_train(self, x, targets=None):
+ if not self.ds:
+ visual_feats = x + self.vis_pos_embed
+ else:
+ visual_feats = x
+ bs = visual_feats.shape[0]
+
+ if self.ch:
+ char_node_embed = self.char_node_embed(targets[-2])
+ else:
+ char_node_embed = self.char_node_embed(
+ torch.arange(self.out_channels).cuda(
+ x.get_device())).unsqueeze(0)
+ char_node_embed = torch.tile(char_node_embed, [bs, 1, 1])
+ counting_char_num = char_node_embed.shape[1]
+ pos_node_embed = self.pos_node_embed(
+ torch.arange(self.max_len).cuda(
+ x.get_device())).unsqueeze(0) + self.char_pos_embed
+ pos_node_embed = torch.tile(pos_node_embed, [bs, 1, 1])
+
+ node_feats = []
+
+ char_vis_node_query = torch.concat([char_node_embed, visual_feats], 1)
+ pos_vis_node_query = torch.concat([pos_node_embed, visual_feats], 1)
+
+ for char_decoder_layer, pos_decoder_layer in zip(
+ self.char_node_decoder, self.pos_node_decoder):
+ char_vis_node_query = char_decoder_layer(
+ char_vis_node_query,
+ char_vis_node_query[:, counting_char_num:, :])
+ pos_vis_node_query = pos_decoder_layer(
+ pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :])
+
+ char_node_query = char_vis_node_query[:, :counting_char_num, :]
+ pos_node_query = pos_vis_node_query[:, :self.max_len, :]
+
+ char_vis_feats = char_vis_node_query[:, counting_char_num:, :]
+
+ char_node_feats1 = self.char_node_fc1(char_node_query)
+ pos_node_feats1 = self.pos_node_fc1(pos_node_query)
+ if not self.pos_len:
+ diag_mask = torch.eye(pos_node_feats1.shape[1]).unsqueeze(0).tile(
+ [pos_node_feats1.shape[0], 1, 1])
+ pos_node_feats1 = (
+ pos_node_feats1 *
+ diag_mask.cuda(pos_node_feats1.get_device())).sum(-1)
+
+ node_feats.append(char_node_feats1)
+ node_feats.append(pos_node_feats1)
+
+ pos_node_feats = pos_node_query
+ for layer_i in range(self.rec_layer_num):
+ rec_layer = self.edge_decoder[layer_i]
+ if (self.rec_layer_num + layer_i) % 2 == 0:
+ pos_node_feats = rec_layer(pos_node_feats, pos_node_feats,
+ self.self_mask)
+ else:
+ pos_node_feats = rec_layer(pos_node_feats, char_vis_feats)
+ edge_feats = self.edge_fc(pos_node_feats) # B, 26, 37
+
+ return node_feats, edge_feats
diff --git a/openrec/modeling/decoders/ctc_decoder.py b/openrec/modeling/decoders/ctc_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..aca917b8cba1c3aef050fedd5d46e588e17ad2c4
--- /dev/null
+++ b/openrec/modeling/decoders/ctc_decoder.py
@@ -0,0 +1,203 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from openrec.modeling.encoders.svtrnet import (
+ Block,
+ ConvBNLayer,
+ kaiming_normal_,
+ trunc_normal_,
+ zeros_,
+ ones_,
+)
+
+
+class Swish(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x * F.sigmoid(x)
+
+
+class EncoderWithSVTR(nn.Module):
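+    """Lightweight SVTR neck: mixes local conv features with global SVTR
+    attention blocks and fuses them with the input via a shortcut concat.
+    With `use_pool`, the feature map is first pooled to height 1 and half
+    the width."""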
+
+ def __init__(
+ self,
+ in_channels,
+ dims=64, # XS
+ depth=2,
+ hidden_dims=120,
+ use_guide=False,
+ num_heads=8,
+ qkv_bias=True,
+ mlp_ratio=2.0,
+ drop_rate=0.1,
+ attn_drop_rate=0.1,
+ drop_path=0.0,
+ kernel_size=[3, 3],
+ qk_scale=None,
+ use_pool=True,
+ ):
+ super(EncoderWithSVTR, self).__init__()
+ self.depth = depth
+ self.use_guide = use_guide
+ self.use_pool = use_pool
+ self.conv1 = ConvBNLayer(
+ in_channels,
+ in_channels // 8,
+ kernel_size=kernel_size,
+ padding=[kernel_size[0] // 2, kernel_size[1] // 2],
+ act=Swish,
+ bias=False)
+ self.conv2 = ConvBNLayer(in_channels // 8,
+ hidden_dims,
+ kernel_size=1,
+ act=Swish,
+ bias=False)
+
+ self.svtr_block = nn.ModuleList([
+ Block(
+ dim=hidden_dims,
+ num_heads=num_heads,
+ mixer='Global',
+ HW=None,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=Swish,
+ attn_drop=attn_drop_rate,
+ drop_path=drop_path,
+ norm_layer='nn.LayerNorm',
+ eps=1e-05,
+ prenorm=False,
+ ) for i in range(depth)
+ ])
+ self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
+ self.conv3 = ConvBNLayer(hidden_dims,
+ in_channels,
+ kernel_size=1,
+ act=Swish,
+ bias=False)
+        # last conv-nxn; its input is the concat of the input tensor and the conv3 output
+ self.conv4 = ConvBNLayer(
+ 2 * in_channels,
+ in_channels // 8,
+ kernel_size=kernel_size,
+ padding=[kernel_size[0] // 2, kernel_size[1] // 2],
+ act=Swish,
+ bias=False)
+
+ self.conv1x1 = ConvBNLayer(in_channels // 8,
+ dims,
+ kernel_size=1,
+ act=Swish,
+ bias=False)
+ self.out_channels = dims
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, mean=0, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ if isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+ if isinstance(m, nn.Conv2d):
+ kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+
+ def pool_h_2(self, x):
+ # x: B, C, H, W
+ x = x.mean(dim=2, keepdim=True)
+ x = F.avg_pool2d(x, kernel_size=(1, 2))
+ return x # B, C, 1, W//2
+
+ def forward(self, x):
+
+ if self.use_pool:
+ x = self.pool_h_2(x)
+
+ # for use guide
+ if self.use_guide:
+ z = x.detach()
+ else:
+ z = x
+ # for short cut
+ h = z
+ # reduce dim
+ z = self.conv1(z)
+ z = self.conv2(z)
+ # SVTR global block
+ B, C, H, W = z.shape
+ z = z.flatten(2).transpose(1, 2)
+ for blk in self.svtr_block:
+ z = blk(z)
+ z = self.norm(z)
+ # last stage
+ z = z.reshape(-1, H, W, C).permute(0, 3, 1, 2)
+ z = self.conv3(z)
+ z = torch.concat((h, z), dim=1)
+ z = self.conv1x1(self.conv4(z))
+ return z
+
+
+class CTCDecoder(nn.Module):
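+    """CTC recognition head: an optional `EncoderWithSVTR` neck followed by
+    one or two linear layers that project every time step to `out_channels`
+    class logits; softmax is applied only at inference time."""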
+
+ def __init__(self,
+ in_channels,
+ out_channels=6625,
+ mid_channels=None,
+ return_feats=False,
+ svtr_encoder=None,
+ **kwargs):
+ super(CTCDecoder, self).__init__()
+ if svtr_encoder is not None:
+ svtr_encoder['in_channels'] = in_channels
+ self.svtr_encoder = EncoderWithSVTR(**svtr_encoder)
+ in_channels = self.svtr_encoder.out_channels
+ else:
+ self.svtr_encoder = None
+ if mid_channels is None:
+ self.fc = nn.Linear(
+ in_channels,
+ out_channels,
+ )
+ else:
+ self.fc1 = nn.Linear(
+ in_channels,
+ mid_channels,
+ )
+ self.fc2 = nn.Linear(
+ mid_channels,
+ out_channels,
+ )
+
+ self.out_channels = out_channels
+ self.mid_channels = mid_channels
+ self.return_feats = return_feats
+
+ def forward(self, x, data=None):
+
+ if self.svtr_encoder is not None:
+ x = self.svtr_encoder(x)
+ x = x.flatten(2).transpose(1, 2)
+
+ if self.mid_channels is None:
+ predicts = self.fc(x)
+ else:
+ x = self.fc1(x)
+ predicts = self.fc2(x)
+
+ if self.return_feats:
+ result = (x, predicts)
+ else:
+ result = predicts
+
+ if not self.training:
+ predicts = F.softmax(predicts, dim=2)
+ result = predicts
+
+ return result
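+
+
+# A minimal usage sketch (illustrative only; the input shape and the SVTR neck
+# settings below are assumptions, not fixed values):
+#   head = CTCDecoder(in_channels=256, out_channels=6625,
+#                     svtr_encoder=dict(dims=64, depth=2, hidden_dims=120))
+#   logits = head(torch.randn(2, 256, 1, 32))  # [B, T, 6625] per-step scores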
diff --git a/openrec/modeling/decoders/dan_decoder.py b/openrec/modeling/decoders/dan_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4f083e101db6608c9e9231c9e47ec309511a29f
--- /dev/null
+++ b/openrec/modeling/decoders/dan_decoder.py
@@ -0,0 +1,203 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CAM(nn.Module):
+ '''
+ Convolutional Alignment Module
+ '''
+
+    # The current version only supports inputs whose size is a power of 2 (e.g. 32, 64, 128).
+ # You can adapt it to any input size by changing the padding or stride.
+ def __init__(self,
+ channels_list=[64, 128, 256, 512],
+ strides_list=[[2, 2], [1, 1], [1, 1]],
+ in_shape=[8, 32],
+ maxT=25,
+ depth=4,
+ num_channels=128):
+ super(CAM, self).__init__()
+ # cascade multiscale features
+ fpn = []
+ for i in range(1, len(channels_list)):
+ fpn.append(
+ nn.Sequential(
+ nn.Conv2d(channels_list[i - 1], channels_list[i], (3, 3),
+ (strides_list[i - 1][0], strides_list[i - 1][1]),
+ 1), nn.BatchNorm2d(channels_list[i]),
+ nn.ReLU(True)))
+ self.fpn = nn.Sequential(*fpn)
+ # convolutional alignment
+ # convs
+        assert depth % 2 == 0, 'the depth of CAM must be an even number.'
+ # in_shape = scales[-1]
+ strides = []
+ conv_ksizes = []
+ deconv_ksizes = []
+ h, w = in_shape[0], in_shape[1]
+ for i in range(0, int(depth / 2)):
+ stride = [2] if 2**(depth / 2 - i) <= h else [1]
+ stride = stride + [2] if 2**(depth / 2 - i) <= w else stride + [1]
+ strides.append(stride)
+ conv_ksizes.append([3, 3])
+ deconv_ksizes.append([_**2 for _ in stride])
+ convs = [
+ nn.Sequential(
+ nn.Conv2d(channels_list[-1], num_channels,
+ tuple(conv_ksizes[0]), tuple(strides[0]),
+ (int((conv_ksizes[0][0] - 1) / 2),
+ int((conv_ksizes[0][1] - 1) / 2))),
+ nn.BatchNorm2d(num_channels), nn.ReLU(True))
+ ]
+ for i in range(1, int(depth / 2)):
+ convs.append(
+ nn.Sequential(
+ nn.Conv2d(num_channels, num_channels,
+ tuple(conv_ksizes[i]), tuple(strides[i]),
+ (int((conv_ksizes[i][0] - 1) / 2),
+ int((conv_ksizes[i][1] - 1) / 2))),
+ nn.BatchNorm2d(num_channels), nn.ReLU(True)))
+ self.convs = nn.Sequential(*convs)
+ # deconvs
+ deconvs = []
+ for i in range(1, int(depth / 2)):
+ deconvs.append(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ num_channels, num_channels,
+ tuple(deconv_ksizes[int(depth / 2) - i]),
+ tuple(strides[int(depth / 2) - i]),
+ (int(deconv_ksizes[int(depth / 2) - i][0] / 4.),
+ int(deconv_ksizes[int(depth / 2) - i][1] / 4.))),
+ nn.BatchNorm2d(num_channels), nn.ReLU(True)))
+ deconvs.append(
+ nn.Sequential(
+ nn.ConvTranspose2d(num_channels, maxT, tuple(deconv_ksizes[0]),
+ tuple(strides[0]),
+ (int(deconv_ksizes[0][0] / 4.),
+ int(deconv_ksizes[0][1] / 4.))),
+ nn.Sigmoid()))
+ self.deconvs = nn.Sequential(*deconvs)
+
+ def forward(self, input):
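+        # `input` is a list of multi-scale backbone features: fuse them
+        # top-down with the FPN convs, then run the conv/deconv pyramid (with
+        # skip connections) to produce maxT sigmoid attention maps.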
+ x = input[0]
+ for i in range(0, len(self.fpn)):
+ # print(self.fpn[i](x).shape, input[i+1].shape)
+ x = self.fpn[i](x) + input[i + 1]
+ conv_feats = []
+ for i in range(0, len(self.convs)):
+ x = self.convs[i](x)
+ conv_feats.append(x)
+ for i in range(0, len(self.deconvs) - 1):
+ x = self.deconvs[i](x)
+ x = x + conv_feats[len(conv_feats) - 2 - i]
+ x = self.deconvs[-1](x)
+ return x
+
+
+class CAMSimp(nn.Module):
+
+ def __init__(self, maxT=25, num_channels=128):
+ super(CAMSimp, self).__init__()
+ self.conv = nn.Sequential(nn.Conv2d(num_channels, maxT, 1, 1, 0),
+ nn.Sigmoid())
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+
+class DANDecoder(nn.Module):
+ '''
+ Decoupled Text Decoder
+ '''
+
+ def __init__(self,
+ out_channels,
+ in_channels,
+ use_cam=True,
+ max_len=25,
+ channels_list=[64, 128, 256, 512],
+ strides_list=[[2, 2], [1, 1], [1, 1]],
+ in_shape=[8, 32],
+ depth=4,
+ dropout=0.3,
+ **kwargs):
+ super(DANDecoder, self).__init__()
+ self.eos = 0
+ self.bos = out_channels - 2
+ self.ignore_index = out_channels - 1
+ nchannel = in_channels
+ self.nchannel = in_channels
+ self.use_cam = use_cam
+ if use_cam:
+ self.cam = CAM(channels_list=channels_list,
+ strides_list=strides_list,
+ in_shape=in_shape,
+ maxT=max_len + 1,
+ depth=depth,
+ num_channels=nchannel)
+ else:
+ self.cam = CAMSimp(maxT=max_len + 1, num_channels=nchannel)
+ self.pre_lstm = nn.LSTM(nchannel,
+ int(nchannel / 2),
+ bidirectional=True)
+ self.rnn = nn.GRUCell(nchannel * 2, nchannel)
+ self.generator = nn.Sequential(nn.Dropout(p=dropout),
+ nn.Linear(nchannel, out_channels - 2))
+ self.char_embeddings = nn.Embedding(out_channels,
+ embedding_dim=in_channels,
+ padding_idx=out_channels - 1)
+
+ def forward(self, inputs, data=None):
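+        # A: [B, maxT+1, H, W] attention maps from the CAM; each map pools the
+        # visual features into one character feature per decoding step.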
+ A = self.cam(inputs)
+ if isinstance(inputs, list):
+ feature = inputs[-1]
+ else:
+ feature = inputs
+ nB, nC, nH, nW = feature.shape
+ nT = A.shape[1]
+ # Normalize
+ A = A / A.view(nB, nT, -1).sum(2).view(nB, nT, 1, 1)
+ # weighted sum
+ C = feature.view(nB, 1, nC, nH, nW) * A.view(nB, nT, 1, nH, nW)
+ C = C.view(nB, nT, nC, -1).sum(3).transpose(1, 0) # T, B, C
+ C, _ = self.pre_lstm(C) # T, B, C
+ C = F.dropout(C, p=0.3, training=self.training)
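+        # Training uses teacher forcing with the ground-truth previous
+        # character; inference feeds back the argmax prediction and stops once
+        # every sample has emitted an EOS token.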
+ if self.training:
+ text = data[0]
+ text_length = data[-1]
+ nsteps = int(text_length.max())
+ gru_res = torch.zeros_like(C)
+ hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
+ prev_emb = self.char_embeddings(text[:, 0])
+ for i in range(0, nsteps + 1):
+ hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
+ hidden)
+ gru_res[i, :, :] = hidden
+ prev_emb = self.char_embeddings(text[:, i + 1])
+ gru_res = self.generator(gru_res)
+ return gru_res[:nsteps + 1, :, :].transpose(1, 0)
+ else:
+ gru_res = torch.zeros_like(C)
+ hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
+ prev_emb = self.char_embeddings(
+ torch.zeros(nB, dtype=torch.int64, device=feature.device) +
+ self.bos)
+ dec_seq = torch.full((nB, nT),
+ self.ignore_index,
+ dtype=torch.int64,
+ device=feature.get_device())
+
+ for i in range(0, nT):
+ hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
+ hidden)
+ gru_res[i, :, :] = hidden
+ mid_res = self.generator(hidden).argmax(-1)
+ dec_seq[:, i] = mid_res.squeeze(0)
+ if (dec_seq == self.eos).any(dim=-1).all():
+ break
+ prev_emb = self.char_embeddings(mid_res)
+ gru_res = self.generator(gru_res)
+ return F.softmax(gru_res.transpose(1, 0), -1)
diff --git a/openrec/modeling/decoders/igtr_decoder.py b/openrec/modeling/decoders/igtr_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..491e84bd91a27853bd302a4a9c4c14d8b7e9223e
--- /dev/null
+++ b/openrec/modeling/decoders/igtr_decoder.py
@@ -0,0 +1,815 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.init import ones_, trunc_normal_, zeros_
+
+from openrec.modeling.common import DropPath, Identity, Mlp
+from openrec.modeling.decoders.nrtr_decoder import Embeddings
+
+
+class CrossAttention(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, q, kv, key_mask=None):
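+        # q: [B, Nq, C] queries, kv: [B, Nk, C] keys/values; key_mask is an
+        # additive mask (0 / -inf) over key positions, broadcast over heads.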
+ N, C = kv.shape[1:]
+ QN = q.shape[1]
+ q = self.q(q).reshape([-1, QN, self.num_heads,
+ C // self.num_heads]).transpose(1, 2)
+ q = q * self.scale
+ k, v = self.kv(kv).reshape(
+ [-1, N, 2, self.num_heads,
+ C // self.num_heads]).permute(2, 0, 3, 1, 4)
+
+ attn = q.matmul(k.transpose(2, 3))
+
+ if key_mask is not None:
+ attn = attn + key_mask.unsqueeze(1)
+
+ attn = F.softmax(attn, -1)
+ if not self.training:
+ self.attn_map = attn
+ attn = self.attn_drop(attn)
+
+ x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class EdgeDecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=[0.0, 0.0],
+ act_layer=nn.GELU,
+ norm_layer='nn.LayerNorm',
+ epsilon=1e-6,
+ ):
+ super().__init__()
+
+ self.head_dim = dim // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path1 = DropPath(
+ drop_path[0]) if drop_path[0] > 0.0 else Identity()
+ self.norm1 = eval(norm_layer)(dim, eps=epsilon)
+ self.norm2 = eval(norm_layer)(dim, eps=epsilon)
+
+ # self.c = nn.Linear(dim, dim*2)
+ self.p = nn.Linear(dim, dim)
+ self.cv = nn.Linear(dim, dim)
+ self.pv = nn.Linear(dim, dim)
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.p_proj = nn.Linear(dim, dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp_ratio = mlp_ratio
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, p, cv, pv):
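+        # Attention weights are computed between `p` and `pv`; the values that
+        # get aggregated come from `cv`.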
+ pN = p.shape[1]
+ vN = cv.shape[1]
+ p_shortcut = p
+
+ p1 = self.p(p).reshape(
+ [-1, pN, self.num_heads,
+ self.dim // self.num_heads]).transpose(1, 2)
+ cv1 = self.cv(cv).reshape(
+ [-1, vN, self.num_heads,
+ self.dim // self.num_heads]).transpose(1, 2)
+ pv1 = self.pv(pv).reshape(
+ [-1, vN, self.num_heads,
+ self.dim // self.num_heads]).transpose(1, 2)
+
+ edge = F.softmax(p1.matmul(pv1.transpose(2, 3)), -1) # B h N N
+
+ p_c = (edge @ cv1).transpose(1, 2).reshape((-1, pN, self.dim))
+
+ x1 = self.norm1(p_shortcut + self.drop_path1(self.p_proj(p_c)))
+
+ x = self.norm2(x1 + self.drop_path1(self.mlp(x1)))
+ return x
+
+
+class DecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer='nn.LayerNorm',
+ epsilon=1e-6,
+ ):
+ super().__init__()
+ self.norm1 = eval(norm_layer)(dim, eps=epsilon)
+ self.normkv = eval(norm_layer)(dim, eps=epsilon)
+
+ self.mixer = CrossAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+
+ self.norm2 = eval(norm_layer)(dim, eps=epsilon)
+
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp_ratio = mlp_ratio
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, q, kv, key_mask=None):
+ x1 = q + self.drop_path(
+ self.mixer(self.norm1(q), self.normkv(kv), key_mask))
+ x = x1 + self.drop_path(self.mlp(self.norm2(x1)))
+ return x
+
+
+class CMFFLayer(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ epsilon=1e-6,
+ ):
+ super().__init__()
+ self.normq1 = nn.LayerNorm(dim, eps=epsilon)
+ self.normkv1 = nn.LayerNorm(dim, eps=epsilon)
+ self.images_to_question_cross_attn = CrossAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.normq2 = nn.LayerNorm(dim, eps=epsilon)
+ self.normkv2 = nn.LayerNorm(dim, eps=epsilon)
+ self.question_to_images_cross_attn = CrossAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ self.normmlp = nn.LayerNorm(dim, eps=epsilon)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, question_f, prompt_f, visual_f, mask=None):
+
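+        # Concatenate question, prompt and visual tokens, let them attend to
+        # the (masked) prompt and then to the visual part of the sequence, and
+        # split the result back into the three updated streams.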
+ query_add = torch.concat([question_f, prompt_f, visual_f], 1)
+
+ query_add = query_add + self.drop_path(
+ self.images_to_question_cross_attn(self.normq1(query_add),
+ self.normkv1(prompt_f), mask))
+ query_add = query_add + self.drop_path(
+ self.question_to_images_cross_attn(
+ self.normq2(query_add),
+ self.normkv2(query_add[:, -visual_f.shape[1]:, :])))
+ query_updated = query_add + self.drop_path(
+ self.mlp(self.normmlp(query_add)))
+
+ question_f_updated = query_updated[:, :question_f.shape[1], :]
+ prompt_f_updated = query_updated[:, question_f.
+ shape[1]:-visual_f.shape[1], :]
+ visual_f_updated = query_updated[:, -visual_f.shape[1]:, :]
+
+ return question_f_updated, prompt_f_updated, visual_f_updated
+
+
+class IGTRDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ dim,
+ out_channels,
+ num_layer=2,
+ drop_path_rate=0.1,
+ max_len=25,
+ vis_seq=50,
+ ch=False,
+ ar=False,
+ refine_iter=0,
+ quesall=True,
+ next_pred=False,
+ ds=False,
+ pos2d=False,
+ check_search=False,
+ max_size=[8, 32],
+ **kwargs):
+ super(IGTRDecoder, self).__init__()
+
+ self.out_channels = out_channels
+ self.dim = dim
+ self.max_len = max_len + 3 # max_len + eos + bos
+ self.ch = ch
+ self.char_embed = Embeddings(d_model=dim,
+ vocab=self.out_channels,
+ scale_embedding=True)
+ self.ignore_index = out_channels - 1
+ self.ar = ar
+ self.refine_iter = refine_iter
+ self.bos = self.out_channels - 2
+ self.eos = 0
+ self.next_pred = next_pred
+ self.quesall = quesall
+ self.check_search = check_search
+ dpr = np.linspace(0, drop_path_rate, num_layer + 2)
+
+ self.cmff_decoder = nn.ModuleList([
+ CMFFLayer(dim=dim,
+ num_heads=dim // 32,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=dpr[i]) for i in range(num_layer)
+ ])
+
+ self.answer_to_question_layer = DecoderLayer(dim=dim,
+ num_heads=dim // 32,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=dpr[-2])
+ self.answer_to_image_layer = DecoderLayer(dim=dim,
+ num_heads=dim // 32,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=dpr[-1])
+
+ self.char_pos_embed = nn.Parameter(torch.zeros([self.max_len, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+ self.appear_num_embed = nn.Parameter(torch.zeros([self.max_len, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+ self.ds = ds
+ self.pos2d = pos2d
+ if not ds:
+ self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+ trunc_normal_(self.vis_pos_embed, std=0.02)
+ elif pos2d:
+ pos_embed = torch.zeros([1, max_size[0] * max_size[1], dim],
+ dtype=torch.float32)
+ trunc_normal_(pos_embed, mean=0, std=0.02)
+ self.vis_pos_embed = nn.Parameter(
+ pos_embed.transpose(1, 2).reshape(1, dim, max_size[0],
+ max_size[1]),
+ requires_grad=True,
+ )
+ self.prompt_pos_embed = nn.Parameter(torch.zeros([1, 6, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+
+ self.answer_query = nn.Parameter(torch.zeros([1, 1, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+ self.norm_pred = nn.LayerNorm(dim, eps=1e-6)
+ self.ques1_head = nn.Linear(dim, self.out_channels - 2)
+ self.ques2_head = nn.Linear(dim, self.max_len, bias=False)
+ self.ques3_head = nn.Linear(dim, self.max_len - 1)
+ self.ques4_head = nn.Linear(dim, self.max_len - 1)
+ trunc_normal_(self.char_pos_embed, std=0.02)
+ trunc_normal_(self.appear_num_embed, std=0.02)
+ trunc_normal_(self.answer_query, std=0.02)
+ trunc_normal_(self.prompt_pos_embed, std=0.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {
+ 'char_pos_embed', 'vis_pos_embed', 'appear_num_embed',
+ 'answer_query', 'char_embed'
+ }
+
+ def question_encoder(self, targets, train_i):
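+        # Build the prompt embeddings (known position/character pairs) and the
+        # question embeddings for the sub-tasks, trimmed to the longest
+        # prompt/question in the batch; the returned mask hides padded prompt
+        # positions.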
+ (
+ prompt_pos_idx,
+ prompt_char_idx,
+ ques_pos_idx,
+ ques1_answer,
+ ques2_char_idx,
+ ques2_answer,
+ ques4_char_num,
+ ques_len,
+ ques2_len,
+ prompt_len,
+ ) = targets
+ max_ques_len = torch.max(ques_len)
+ max_ques2_len = torch.max(ques2_len)
+ max_prompt_len = torch.max(prompt_len)
+ if self.next_pred and (train_i == 2 or train_i == 3):
+ prompt_pos = self.prompt_pos_embed
+ prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
+ else:
+ prompt_pos = F.embedding(
+ prompt_pos_idx[:, :max_prompt_len], self.char_pos_embed
+ ) # bs lp [ 0, 4, 3, 12, 12, 12, 12, 12, 12, 12, 12]
+ prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
+ prompt_char = self.char_embed(prompt_char_idx) # bs lp
+
+ prompt = prompt_pos + prompt_char
+ mask_1234 = torch.where(prompt_char_idx == self.ignore_index,
+ float('-inf'), 0)
+
+ ques1 = F.embedding(ques_pos_idx[:, :max_ques_len],
+ self.char_pos_embed) # bs lq1 dim
+ ques1_answer = ques1_answer[:, :max_ques_len]
+ if self.quesall or train_i == 0:
+ ques2_char = self.char_embed(ques2_char_idx[:, :max_ques2_len, 1])
+ ques2 = ques2_char + F.embedding(ques2_char_idx[:, :max_ques2_len,
+ 0],
+ self.char_pos_embed) # bs lq2 dim
+ ques2_answer = ques2_answer[:, :max_ques2_len]
+ ques2_head = F.embedding(ques2_char_idx[:, :max_ques2_len, 0],
+ self.ques2_head.weight)
+ ques4_char = self.char_embed(ques1_answer)
+ ques4_ap_num = F.embedding(ques4_char_num[:, :max_ques_len],
+ self.appear_num_embed)
+ ques4 = ques4_char + ques4_ap_num
+ ques4_answer = ques_pos_idx[:, :max_ques_len]
+
+ return (
+ prompt,
+ ques1,
+ ques2,
+ ques2_head,
+ ques4,
+ ques1_answer,
+ ques2_answer,
+ ques4_answer,
+ mask_1234.unsqueeze(1),
+ )
+ else:
+ return prompt, ques1, ques1_answer, mask_1234.unsqueeze(1)
+
+ def forward(self, x, data=None):
+ if self.training:
+ return self.forward_train(x, data)
+ else:
+ return self.forward_test(x)
+
+ def forward_test(self, x):
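+        # Inference supports several decoding modes: parallel decoding
+        # (default), token-by-token verification (`check_search`), next-token
+        # prediction (`next_pred`), autoregressive decoding (`ar`), and an
+        # optional refinement stage when `refine_iter` > 0.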
+ if not self.ds:
+ visual_f = x + self.vis_pos_embed
+ elif self.pos2d:
+ x = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
+ visual_f = x.flatten(2).transpose(1, 2)
+ else:
+ visual_f = x
+ bs = x.shape[0]
+ prompt_bos = self.char_embed(
+ torch.full(
+ [bs, 1], self.bos, dtype=torch.long,
+ device=x.get_device())) + self.char_pos_embed[:1, :].unsqueeze(
+ 0) # BOS prompt
+ ques_all = torch.tile(self.char_pos_embed.unsqueeze(0), (bs, 1, 1))
+ if not self.ar:
+ if self.check_search:
+ tgt_in = torch.full((bs, self.max_len),
+ self.ignore_index,
+ dtype=torch.long,
+ device=x.get_device())
+ tgt_in[:, 0] = self.bos
+ logits = []
+ for j in range(1, self.max_len):
+ visual_f_check = visual_f
+ ques_check_i = ques_all[:, j:j + 1, :] + self.char_embed(
+ torch.arange(self.out_channels - 2,
+ device=x.get_device())).unsqueeze(0)
+ prompt_check = ques_all[:, :j] + self.char_embed(
+ tgt_in[:, :j])
+ # prompt_check = prompt_bos
+ mask = torch.where(
+ (tgt_in[:, :j] == self.eos).int().cumsum(-1) > 0,
+ float('-inf'), 0)
+ for layer in self.cmff_decoder:
+ ques_check_i, prompt_check, visual_f_check = layer(
+ ques_check_i, prompt_check, visual_f_check,
+ mask.unsqueeze(1))
+ answer_query_i = self.answer_to_question_layer(
+ ques_check_i, prompt_check, mask.unsqueeze(1))
+ answer_pred_i = self.norm_pred(
+ self.answer_to_image_layer(
+ answer_query_i, visual_f_check)) # B, 26, 37
+ # the next token probability is in the output's ith token position
+ fc_2 = self.ques2_head.weight[j:j + 1].unsqueeze(0)
+ fc_2 = fc_2.tile([bs, 1, 1])
+ p_i = fc_2 @ answer_pred_i.transpose(1, 2)
+ # p_i = p_i[:, 0, :]
+ logits.append(p_i)
+ if j < self.max_len - 1:
+ # greedy decode. add the next token index to the target input
+ tgt_in[:, j] = p_i.squeeze().argmax(-1)
+
+ # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
+ if (tgt_in == self.eos).any(dim=-1).all():
+ break
+ logits = torch.cat(logits, dim=1)
+ else:
+ ques_pd = ques_all[:, 1:, :]
+ prompt_pd = prompt_bos
+ visual_f_pd = visual_f
+ for layer in self.cmff_decoder:
+ ques_pd, prompt_pd, visual_f_pd = layer(
+ ques_pd, prompt_pd, visual_f_pd)
+ answer_query_pd = self.answer_to_question_layer(
+ ques_pd, prompt_pd)
+ answer_feats_pd = self.norm_pred(
+ self.answer_to_image_layer(answer_query_pd,
+ visual_f_pd)) # B, 26, 37
+ logits = self.ques1_head(answer_feats_pd)
+ elif self.next_pred:
+ ques_pd_1 = ques_all[:, 1:2, :]
+ prompt_pd = prompt_bos
+ visual_f_pd = visual_f
+ for layer in self.cmff_decoder:
+ ques_pd_1, prompt_pd, visual_f_pd = layer(
+ ques_pd_1, prompt_pd, visual_f_pd)
+ answer_query_pd = self.answer_to_question_layer(
+ ques_pd_1, prompt_pd)
+ answer_feats_pd = self.norm_pred(
+ self.answer_to_image_layer(answer_query_pd,
+ visual_f_pd)) # B, 26, 37
+ logits_pd_1 = self.ques1_head(answer_feats_pd)
+
+ ques_next = self.char_pos_embed[-2:-1, :].unsqueeze(0).tile(
+ [bs, 1, 1])
+ prompt_next_bos = (self.char_embed(
+ torch.full(
+ [bs, 1], self.bos, dtype=torch.long,
+ device=x.get_device())) + self.prompt_pos_embed[:, :1, :])
+ pred_prob, pred_id = F.softmax(logits_pd_1, -1).max(-1)
+ pred_prob_list = [pred_prob]
+ pred_id_list = [pred_id]
+ for j in range(1, 70):
+ prompt_next_1 = self.char_embed(
+ pred_id) + self.prompt_pos_embed[:,
+ -1 * pred_id.shape[1]:, :]
+ prompt_next = torch.concat([prompt_next_bos, prompt_next_1], 1)
+ ques_next_i = ques_next
+ visual_f_i = visual_f
+ for layer in self.cmff_decoder:
+ ques_next_i, prompt_next, visual_f_pd = layer(
+ ques_next_i, prompt_next, visual_f_i)
+ answer_query_next_i = self.answer_to_question_layer(
+ ques_next_i, prompt_next)
+ answer_feats_next_i = self.norm_pred(
+ self.answer_to_image_layer(answer_query_next_i,
+ visual_f_i)) # B, 26, 37
+ logits_next_i = self.ques1_head(answer_feats_next_i)
+ # pred_id = logits_next_i.argmax(-1)
+ pred_prob_i, pred_id_i = F.softmax(logits_next_i, -1).max(-1)
+ pred_prob_list.append(pred_prob_i)
+ pred_id_list.append(pred_id_i)
+ if (torch.concat(pred_id_list,
+ 1) == self.eos).any(dim=-1).all():
+ break
+ if pred_id.shape[1] >= 5:
+ pred_id = torch.concat([pred_id[:, 1:], pred_id_i], 1)
+ else:
+ pred_id = torch.concat([pred_id, pred_id_i], 1)
+ return [
+ torch.concat(pred_id_list, 1),
+ torch.concat(pred_prob_list, 1)
+ ]
+
+ else:
+ tgt_in = torch.full((bs, self.max_len),
+ self.ignore_index,
+ dtype=torch.long,
+ device=x.get_device())
+ tgt_in[:, 0] = self.bos
+ logits = []
+ for j in range(1, self.max_len):
+ visual_f_ar = visual_f
+ ques_i = ques_all[:, j:j + 1, :]
+ prompt_ar = ques_all[:, :j] + self.char_embed(tgt_in[:, :j])
+ mask = torch.where(
+ (tgt_in[:, :j] == self.eos).int().cumsum(-1) > 0,
+ float('-inf'), 0)
+ for layer in self.cmff_decoder:
+ ques_i, prompt_ar, visual_f_ar = layer(
+ ques_i, prompt_ar, visual_f_ar, mask.unsqueeze(1))
+ answer_query_i = self.answer_to_question_layer(
+ ques_i, prompt_ar, mask.unsqueeze(1))
+ answer_pred_i = self.norm_pred(
+ self.answer_to_image_layer(answer_query_i,
+ visual_f_ar)) # B, 26, 37
+ # the next token probability is in the output's ith token position
+ p_i = self.ques1_head(answer_pred_i)
+ logits.append(p_i)
+ if j < self.max_len - 1:
+ # greedy decode. add the next token index to the target input
+ tgt_in[:, j] = p_i.squeeze().argmax(-1)
+
+ # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
+ if (tgt_in == self.eos).any(dim=-1).all():
+ break
+ logits = torch.cat(logits, dim=1)
+
+ if self.refine_iter > 0:
+ pred_probs, pred_idxs = F.softmax(logits, -1).max(-1)
+ for i in range(self.refine_iter):
+
+ mask_check = (pred_idxs == self.eos).int().cumsum(-1) <= 1
+
+ ques_check_all = self.char_embed(
+ pred_idxs) + ques_all[:, 1:pred_idxs.shape[1] + 1, :]
+ prompt_check = prompt_bos
+ visual_f_check = visual_f
+ ques_check = ques_check_all
+ for layer in self.cmff_decoder:
+ ques_check, prompt_check, visual_f_check = layer(
+ ques_check, prompt_check, visual_f_check)
+ answer_query_check = self.answer_to_question_layer(
+ ques_check, prompt_check)
+ answer_pred_check = self.norm_pred(
+ self.answer_to_image_layer(answer_query_check,
+ visual_f_check)) # B, 26, 37
+ ques2_head = self.ques2_head.weight[1:pred_idxs.shape[1] +
+ 1, :]
+ ques2_head = torch.tile(ques2_head.unsqueeze(0), [bs, 1, 1])
+ answer2_pred = answer_pred_check.matmul(
+ ques2_head.transpose(1, 2))
+ diag_mask = torch.eye(answer2_pred.shape[1],
+ device=x.get_device()).unsqueeze(0).tile(
+ [bs, 1, 1])
+ answer2_pred = F.sigmoid(
+ (answer2_pred * diag_mask).sum(-1)) * mask_check
+
+ check_result = answer2_pred < 0.9 # pred_probs < 0.99
+
+ prompt_refine = torch.concat([prompt_bos, ques_check_all], 1)
+ mask_refine = torch.where(
+ check_result, float('-inf'), 0) + torch.where(
+ (pred_idxs == self.eos).int().cumsum(-1) < 1, 0,
+ float('-inf'))
+ mask_refine = torch.concat(
+ [torch.zeros([bs, 1], device=x.get_device()), mask_refine],
+ 1).unsqueeze(1)
+ ques_refine = ques_all[:, 1:pred_idxs.shape[1] + 1, :]
+ visual_f_refine = visual_f
+ for layer in self.cmff_decoder:
+ ques_refine, prompt_refine, visual_f_refine = layer(
+ ques_refine, prompt_refine, visual_f_refine,
+ mask_refine)
+ answer_query_refine = self.answer_to_question_layer(
+ ques_refine, prompt_refine, mask_refine)
+ answer_pred_refine = self.norm_pred(
+ self.answer_to_image_layer(answer_query_refine,
+ visual_f_refine)) # B, 26, 37
+ answer_refine = self.ques1_head(answer_pred_refine)
+ refine_probs, refine_idxs = F.softmax(answer_refine,
+ -1).max(-1)
+ pred_idxs_refine = torch.where(check_result, refine_idxs,
+ pred_idxs)
+ pred_idxs = torch.where(mask_check, pred_idxs_refine,
+ pred_idxs)
+ pred_probs_refine = torch.where(check_result, refine_probs,
+ pred_probs)
+ pred_probs = torch.where(mask_check, pred_probs_refine,
+ pred_probs)
+
+ return [pred_idxs, pred_probs]
+
+ return F.softmax(logits, -1)
+
+ def forward_train(self, x, targets=None):
+
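+        # Each entry of `targets` stacks several sampled question/prompt sets;
+        # iterate over them and accumulate the losses of the four question
+        # heads, each weighted by its number of valid targets and normalized
+        # at the end.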
+ bs = x.shape[0]
+ answer_token = torch.tile(self.answer_query, (bs, 1, 1))
+ if self.ch:
+ ques3 = self.char_embed(targets[7][:, :,
+ 0]) + answer_token # bs nc dim
+ ques3_answer = targets[7][:, :, 1]
+ else:
+ ques3 = self.char_embed(
+ torch.arange(self.out_channels - 2, device=x.get_device())
+ ).unsqueeze(0) + answer_token # bs nc dim
+ ques3_answer = targets[7]
+ loss1_list = []
+ loss2_list = []
+ loss3_list = []
+ loss4_list = []
+ sampler1_num = 0
+ sampler2_num = 0
+ sampler3_num = 0
+ sampler4_num = 0
+ if not self.ds:
+ visual_f = x + self.vis_pos_embed
+ elif self.pos2d:
+ x = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
+ visual_f = x.flatten(2).transpose(1, 2)
+ else:
+ visual_f = x
+ train_i = 0
+ for target_ in zip(
+ targets[1].transpose(0, 1),
+ targets[2].transpose(0, 1),
+ targets[3].transpose(0, 1),
+ targets[4].transpose(0, 1),
+ targets[5].transpose(0, 1),
+ targets[6].transpose(0, 1),
+ targets[8].transpose(0, 1),
+ targets[9].transpose(0, 1),
+ targets[10].transpose(0, 1),
+ targets[11].transpose(0, 1),
+ ):
+ # target_ = [prompt_pos_idx, prompt_char_idx, ques_pos_idx, ques1_answer, \
+ # ques2_char_idx, ques2_answer, ques4_char_num, ques_len, prompt_len]
+ visual_f_1234 = visual_f
+ if self.quesall or train_i == 0:
+ (
+ prompt,
+ ques1,
+ ques2,
+ ques2_head,
+ ques4,
+ ques1_answer,
+ ques2_answer,
+ ques4_answer,
+ mask_1234,
+ ) = self.question_encoder(target_, train_i)
+ prompt_1234 = prompt
+ ques_1234 = torch.concat([ques1, ques2, ques3, ques4], 1)
+ for layer in self.cmff_decoder:
+ ques_1234, prompt_1234, visual_f_1234 = layer(
+ ques_1234, prompt_1234, visual_f_1234, mask_1234)
+ answer_query_1234 = self.answer_to_question_layer(
+ ques_1234, prompt_1234, mask_1234)
+ answer_feats_1234 = self.norm_pred(
+ self.answer_to_image_layer(answer_query_1234,
+ visual_f_1234)) # B, 26, 37
+
+ answer_feats_1 = answer_feats_1234[:, :ques1.shape[1], :]
+ answer_feats_2 = answer_feats_1234[:, ques1.shape[1]:(
+ ques1.shape[1] + ques2.shape[1]), :]
+ answer_feats_3 = answer_feats_1234[:, (
+ ques1.shape[1] + ques2.shape[1]):-ques4.shape[1], :]
+ answer_feats_4 = answer_feats_1234[:, -ques4.shape[1]:, :]
+
+ answer1_pred = self.ques1_head(answer_feats_1)
+ if train_i == 0:
+ logits = answer1_pred
+
+ n = (ques1_answer != self.ignore_index).sum().item()
+ loss1 = n * F.cross_entropy(
+ answer1_pred.flatten(0, 1),
+ ques1_answer.flatten(0, 1),
+ ignore_index=self.ignore_index,
+ reduction='mean',
+ )
+ sampler1_num += n
+ loss1_list.append(loss1)
+
+ answer2_pred = answer_feats_2.matmul(ques2_head.transpose(
+ 1, 2))
+ diag_mask = torch.eye(answer2_pred.shape[1],
+ device=x.get_device()).unsqueeze(0).tile(
+ [bs, 1, 1])
+ answer2_pred = (answer2_pred * diag_mask).sum(-1)
+
+ ques2_answer = ques2_answer.flatten(0, 1)
+ non_pad_mask = torch.not_equal(ques2_answer, self.ignore_index)
+ n = non_pad_mask.sum().item()
+ ques2_answer = torch.where(ques2_answer == self.ignore_index,
+ 0, ques2_answer)
+ loss2_none = F.binary_cross_entropy_with_logits(
+ answer2_pred.flatten(0, 1), ques2_answer, reduction='none')
+ loss2 = n * loss2_none.masked_select(non_pad_mask).mean()
+ sampler2_num += n
+ loss2_list.append(loss2)
+
+ answer3_pred = self.ques3_head(answer_feats_3)
+ n = (ques3_answer != self.ignore_index).sum().item()
+ loss3 = n * F.cross_entropy(answer3_pred.flatten(0, 1),
+ ques3_answer.flatten(0, 1),
+ reduction='mean')
+ sampler3_num += n
+ loss3_list.append(loss3)
+
+ answer4_pred = self.ques4_head(answer_feats_4)
+ n = (ques4_answer != self.max_len - 1).sum().item()
+ loss4 = n * F.cross_entropy(
+ answer4_pred.flatten(0, 1),
+ ques4_answer.flatten(0, 1),
+ ignore_index=self.max_len - 1,
+ reduction='mean',
+ )
+ sampler4_num += n
+ loss4_list.append(loss4)
+ else:
+ prompt, ques1, ques1_answer, mask_1234 = self.question_encoder(
+ target_, train_i)
+ prompt_1234 = prompt
+ for layer in self.cmff_decoder:
+ ques1, prompt_1234, visual_f_1234 = layer(
+ ques1, prompt_1234, visual_f_1234, mask_1234)
+ answer_query_1 = self.answer_to_question_layer(
+ ques1, prompt_1234, mask_1234)
+ answer_feats_1 = self.norm_pred(
+ self.answer_to_image_layer(answer_query_1,
+ visual_f_1234)) # B, 26, 37
+ answer1_pred = self.ques1_head(answer_feats_1)
+ n = (ques1_answer != self.ignore_index).sum().item()
+ loss1 = n * F.cross_entropy(
+ answer1_pred.flatten(0, 1),
+ ques1_answer.flatten(0, 1),
+ ignore_index=self.ignore_index,
+ reduction='mean',
+ )
+ sampler1_num += n
+ loss1_list.append(loss1)
+ train_i += 1
+
+ loss_list = [
+ sum(loss1_list) / sampler1_num,
+ sum(loss2_list) / sampler2_num,
+ sum(loss3_list) / sampler3_num,
+ sum(loss4_list) / sampler4_num,
+ ]
+ loss = {
+ 'loss': sum(loss_list),
+ 'loss1': loss_list[0],
+ 'loss2': loss_list[1],
+ 'loss3': loss_list[2],
+ 'loss4': loss_list[3],
+ }
+ return [loss, logits]
diff --git a/openrec/modeling/decoders/lister_decoder.py b/openrec/modeling/decoders/lister_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a5d8c6bd610bead4e0b270ffa18087dafe792a1
--- /dev/null
+++ b/openrec/modeling/decoders/lister_decoder.py
@@ -0,0 +1,535 @@
+"""This code is refer from:
+https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/LISTER
+"""
+
+# Copyright (2023) Alibaba Group and its affiliates
+# --------------------------------------------------------
+# To decode arbitrary-length text images.
+# --------------------------------------------------------
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import trunc_normal_
+
+from openrec.modeling.encoders.focalsvtr import FocalNetBlock
+
+
+class LocalSelfAttention(nn.Module):
+
+ def __init__(self,
+ feat_dim,
+ nhead,
+ window_size: int,
+ add_pos_bias=False,
+ qkv_drop=0.0,
+ proj_drop=0.0,
+ mlm=False):
+ super().__init__()
+ assert feat_dim % nhead == 0
+ self.q_fc = nn.Linear(feat_dim, feat_dim)
+ self.kv_fc = nn.Linear(feat_dim, feat_dim * 2)
+
+ self.nhead = nhead
+ self.head_dim = feat_dim // nhead
+ self.window_size = window_size
+ if add_pos_bias:
+ self.kv_pos_bias = nn.Parameter(torch.zeros(window_size, feat_dim))
+ trunc_normal_(self.kv_pos_bias, std=.02)
+ else:
+ self.kv_pos_bias = None
+ self.qkv_dropout = nn.Dropout(qkv_drop)
+
+ self.proj = nn.Linear(feat_dim, feat_dim)
+ self.proj_dropout = nn.Dropout(proj_drop)
+ self.mlm = mlm
+ if mlm:
+ print('Use mlm.')
+
+ def _gen_t_index(self, real_len, device):
+ idx = torch.stack([
+ torch.arange(real_len, dtype=torch.long, device=device) + st
+ for st in range(self.window_size)
+ ]).t() # [T, w]
+ return idx
+
+ def _apply_attn_mask(self, attn_score):
+ attn_score[:, :, :, :, self.window_size // 2] = float('-inf')
+ return attn_score
+
+ def forward(self, x, mask):
+ """
+ Args:
+ x: [b, T, C]
+ mask: [b, T]
+ """
+ b, T, C = x.size()
+ # mask with 0
+ x = x * mask.unsqueeze(-1)
+
+ q = self.q_fc(self.qkv_dropout(x)) # [b, T, C]
+ pad_l = pad_r = self.window_size // 2
+ x_pad = F.pad(x, (0, 0, pad_l, pad_r)) # [b, T+w, C]
+ # organize the window-based kv
+ b_idx = torch.arange(b, dtype=torch.long,
+ device=x.device).contiguous().view(b, 1, 1)
+ t_idx = self._gen_t_index(T, x.device).unsqueeze(0)
+ x_pad = x_pad[b_idx, t_idx] # [b, T, w, C]
+ if self.kv_pos_bias is not None:
+ x_pad = self.qkv_dropout(
+ x_pad + self.kv_pos_bias.unsqueeze(0).unsqueeze(1))
+ else:
+ x_pad = self.qkv_dropout(x_pad)
+ kv = self.kv_fc(x_pad) # [b, T, w, 2*C]
+ k, v = kv.chunk(2, -1) # both are [b, T, w, C]
+ # multi-head splitting
+ q = q.contiguous().view(b, T, self.nhead, -1) # [b, T, h, C/h]
+ k = k.contiguous().view(b, T, self.window_size, self.nhead,
+ -1).transpose(2, 3) # [b, T, h, w, C/h]
+ v = v.contiguous().view(b, T, self.window_size, self.nhead,
+ -1).transpose(2, 3)
+ # calculate attention scores
+ # the scaling of qk refers to: https://kexue.fm/archives/8823
+ alpha = q.unsqueeze(3).matmul(
+ k.transpose(-1, -2) / self.head_dim *
+ math.log(self.window_size)) # [b, T, h, 1, w]
+ if self.mlm:
+ alpha = self._apply_attn_mask(alpha)
+ alpha = alpha.softmax(-1)
+ output = alpha.matmul(v).squeeze(-2).contiguous().view(b, T,
+ -1) # [b, T, C]
+ output = self.proj_dropout(self.proj(output))
+ output = output * mask.unsqueeze(-1)
+ return output
+
+
+class LocalAttentionBlock(nn.Module):
+
+ def __init__(self,
+ feat_dim,
+ nhead,
+ window_size,
+ add_pos_bias: bool,
+ drop=0.0,
+ proj_drop=0.0,
+ init_values=1e-6,
+ mlm=False):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(feat_dim)
+ self.sa = LocalSelfAttention(feat_dim,
+ nhead,
+ window_size,
+ add_pos_bias,
+ drop,
+ proj_drop,
+ mlm=mlm)
+ self.norm2 = nn.LayerNorm(feat_dim)
+ self.mlp = nn.Sequential(
+ nn.Linear(feat_dim, feat_dim * 4),
+ nn.GELU(),
+ nn.Dropout(drop),
+ nn.Linear(feat_dim * 4, feat_dim),
+ nn.Dropout(drop),
+ )
+ if init_values > 0:
+ self.gamma_1 = nn.Parameter(init_values * torch.ones(feat_dim),
+ requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones(feat_dim),
+ requires_grad=True)
+ else:
+ self.gamma_1, self.gamma_2 = 1.0, 1.0
+
+ def forward(self, x, mask):
+ x = x + self.gamma_1 * self.sa(self.norm1(x), mask)
+ x = x + self.gamma_2 * self.mlp(self.norm2(x))
+ x = x * mask.unsqueeze(-1)
+ return x
+
+
+class LocalAttentionModule(nn.Module):
+
+ def __init__(self,
+ feat_dim,
+ nhead,
+ window_size,
+ num_layers,
+ drop_rate=0.0,
+ proj_drop_rate=0.0,
+ detach_grad=False,
+ mlm=False):
+ super().__init__()
+ self.attn_blocks = nn.ModuleList([
+ LocalAttentionBlock(
+ feat_dim,
+ nhead,
+ window_size,
+ add_pos_bias=(i == 0),
+ drop=drop_rate,
+ proj_drop=proj_drop_rate,
+ mlm=mlm,
+ ) for i in range(num_layers)
+ ])
+
+ self.detach_grad = detach_grad
+
+ def forward(self, x, mask):
+ if self.detach_grad:
+ x = x.detach()
+ for blk in self.attn_blocks:
+ x = blk(x, mask)
+ return x
+
+
+def softmax_m1(x: torch.Tensor, dim: int):
+    # Softmax-like normalization based on exp(x) - 1, valid for x >= 0: scores
+    # near zero get (almost) zero weight, which sharpens the maps at inference.
+ fx = x.exp() - 1
+ fx = fx / fx.sum(dim, keepdim=True)
+ return fx
+
+
+class BilinearLayer(nn.Module):
+
+ def __init__(self, in1, in2, out, bias=True):
+ super(BilinearLayer, self).__init__()
+ self.weight = nn.Parameter(torch.randn(out, in1, in2))
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out))
+ else:
+ self.bias = None
+ torch.nn.init.xavier_normal_(self.weight, 0.1)
+
+ def forward(self, x1, x2):
+ '''
+ input:
+ x1: [b, T1, in1]
+ x2: [b, T2, in2]
+ output:
+ y: [b, T1, T2, out]
+ '''
+ y = torch.einsum('bim,omn->bino', x1, self.weight) # [b, T1, in2, out]
+ y = torch.einsum('bino,bjn->bijo', y, x2) # [b, T1, T2, out]
+ if self.bias is not None:
+ y = y + self.bias.contiguous().view(1, 1, 1, -1)
+ return y
+
+
+class NeighborDecoder(nn.Module):
+ """Find neighbors for each character In this version, each iteration shares
+ the same decoder with the local vision decoder."""
+
+ def __init__(self,
+ num_classes,
+ feat_dim,
+ max_len=1000,
+ detach_grad=False,
+ **kwargs):
+ super().__init__()
+ self.eos_emb = nn.Parameter(torch.ones(feat_dim))
+ trunc_normal_(self.eos_emb, std=.02)
+ self.q_fc = nn.Linear(feat_dim, feat_dim, bias=True)
+ self.k_fc = nn.Linear(feat_dim, feat_dim)
+
+ self.neighbor_navigator = BilinearLayer(feat_dim, feat_dim, 1)
+
+ self.vis_cls = nn.Linear(feat_dim, num_classes)
+
+ self.p_threshold = 0.6
+ self.max_len = max_len or 1000 # to avoid endless loop
+ self.max_ch = max_len or 1000
+
+ self.detach_grad = detach_grad
+ self.attn_scaling = kwargs['attn_scaling']
+
+ def align_chars(self, start_map, nb_map, max_ch=None):
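+        # Starting from the first-character map, repeatedly hop along the
+        # neighbor map to obtain one attention map per character, until every
+        # sample reaches the EOS position (or max_ch steps).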
+ if self.training:
+ assert max_ch is not None
+ max_ch = max_ch or self.max_ch # required during training to be efficient
+ b, N = nb_map.shape[:2]
+
+ char_map = start_map # [b, N]
+ all_finished = torch.zeros(b, dtype=torch.long, device=nb_map.device)
+ char_maps = []
+ char_masks = []
+ for i in range(max_ch):
+ char_maps.append(char_map)
+ char_mask = (all_finished == 0).float()
+ char_masks.append(char_mask)
+ if i == max_ch - 1:
+ break
+ all_finished = all_finished + (char_map[:, -1] >
+ self.p_threshold).long()
+ if not self.training:
+ # check if end
+ if (all_finished > 0).sum().item() == b:
+ break
+ if self.training:
+ char_map = char_map.unsqueeze(1).matmul(nb_map).squeeze(1)
+ else:
+ # char_map_dt = (char_map.detach() * 50).softmax(-1)
+ k = min(1 + i * 2, 16)
+ char_map_dt = softmax_m1(char_map.detach() * k, dim=-1)
+ char_map = char_map_dt.unsqueeze(1).matmul(nb_map).squeeze(1)
+
+ char_maps = torch.stack(char_maps, dim=1) # [b, L, N], L = n_char + 1
+ char_masks = torch.stack(char_masks, dim=1) # [b, L], 0 denotes masked
+ return char_maps, char_masks
+
+ def forward(self, x: torch.FloatTensor, max_char: int = None):
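+        # Locate the first character with a global query, predict a
+        # position-to-position neighbor map, follow it to collect one feature
+        # per character, and classify the collected features.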
+ b, c, h, w = x.size()
+ x = x.flatten(2).transpose(1, 2) # [b, N, c], N = h x w
+ g = x.mean(1) # global representation, [b, c]
+
+ # append eos emb to x
+ x_ext = torch.cat(
+ [x, self.eos_emb.unsqueeze(0).expand(b, -1).unsqueeze(1)],
+ dim=1) # [b, N+1, c]
+
+ # locate the first character feature
+ q_start = self.q_fc(g) # [b, c]
+ k_feat = self.k_fc(x_ext) # [b, N+1, c]
+ start_map = k_feat.matmul(q_start.unsqueeze(-1)).squeeze(
+ -1) # [b, N+1]
+ # scaling, referring to: https://kexue.fm/archives/8823
+ if self.attn_scaling:
+ start_map = start_map / (c**0.5)
+ start_map = start_map.softmax(1)
+
+ # Neighbor discovering
+ q_feat = self.q_fc(x)
+ nb_map = self.neighbor_navigator(q_feat,
+ k_feat).squeeze(-1) # [b, N, N+1]
+ if self.attn_scaling:
+ nb_map = nb_map / (c**0.5)
+ nb_map = nb_map.softmax(2)
+ last_neighbor = torch.zeros(h * w + 1, device=x.device)
+ last_neighbor[-1] = 1.0
+ nb_map = torch.cat(
+ [
+ nb_map,
+ last_neighbor.contiguous().view(1, 1, -1).expand(b, -1, -1)
+ ],
+ dim=1) # to complete the neighbor matrix, (N+1) x (N+1)
+
+ # string (feature) decoding
+ char_maps, char_masks = self.align_chars(start_map, nb_map, max_char)
+ char_feats = char_maps.matmul(x_ext) # [b, L, c]
+ char_feats = char_feats * char_masks.unsqueeze(-1)
+ logits = self.vis_cls(char_feats) # [b, L, nC]
+
+ results = dict(
+ logits=logits,
+ char_feats=char_feats,
+ char_maps=char_maps,
+ char_masks=char_masks,
+ h=h,
+ nb_map=nb_map,
+ )
+ return results
+
+
+class FeatureMapEnhancer(nn.Module):
+ """ Merge the global and local features
+ """
+
+ def __init__(self,
+ feat_dim,
+ num_layers=1,
+ focal_level=3,
+ max_kh=1,
+ layerscale_value=1e-6,
+ drop_rate=0.0):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(feat_dim)
+ self.merge_layer = nn.ModuleList([
+ FocalNetBlock(
+ dim=feat_dim,
+ mlp_ratio=4,
+ drop=drop_rate,
+ focal_level=focal_level,
+ max_kh=max_kh,
+ focal_window=3,
+ use_layerscale=True,
+ layerscale_value=layerscale_value,
+ ) for i in range(num_layers)
+ ])
+ # self.scale = 1. / (feat_dim ** 0.5)
+ self.norm2 = nn.LayerNorm(feat_dim)
+ self.dropout = nn.Dropout(drop_rate)
+
+ def forward(self, feat_map, feat_char, char_attn_map):
+ """
+ feat_map: [b, N, C]
+ feat_char: [b, T, C], T include the EOS token
+ char_attn_map: [b, T, N], N exclude the EOS token
+ vis_mask: [b, N]
+ h: height of the feature map
+
+ return: [b, C, h, w]
+ """
+ b, C, h, w = feat_map.size()
+ feat_map = feat_map.flatten(2).transpose(1, 2)
+ # 1. restore the char feats into the visual map
+ # char_feat_map = char_attn_map.transpose(1, 2).matmul(feat_char * self.scale) # [b, N, C]
+ char_feat_map = char_attn_map.transpose(1, 2).matmul(
+ feat_char) # [b, N, C]
+ char_feat_map = self.norm1(char_feat_map)
+ feat_map = feat_map + char_feat_map
+
+ # 2. merge
+ # vis_mask = vis_mask.contiguous().view(b, h, -1) # [b, h, w]
+ for blk in self.merge_layer:
+ blk.H, blk.W = h, w
+ feat_map = blk(feat_map)
+ feat_map = self.dropout(self.norm2(feat_map))
+ feat_map = feat_map.transpose(1, 2).reshape(b, C, h, w) # [b, C, h, w]
+ # feat_map = feat_map * vis_mask.unsqueeze(1)
+ return feat_map
+
+
+class LISTERDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ max_len=25,
+ use_fem=True,
+ detach_grad=False,
+ nhead=8,
+ window_size=11,
+ iters=2,
+ num_sa_layers=1,
+ num_mg_layers=1,
+ coef=[1.0, 0.01, 0.001],
+ **kwargs):
+ super().__init__()
+ num_classes = out_channels - 1
+ self.ignore_index = num_classes
+ self.max_len = max_len
+ self.use_fem = use_fem
+ self.detach_grad = detach_grad
+ self.iters = max(1, iters) if use_fem else 0
+ feat_dim = in_channels
+ self.decoder = NeighborDecoder(num_classes,
+ feat_dim,
+ max_len=max_len,
+ detach_grad=detach_grad,
+ **kwargs)
+ if iters > 0 and use_fem:
+ self.cntx_module = LocalAttentionModule(feat_dim,
+ nhead,
+ window_size,
+ num_sa_layers,
+ drop_rate=0.1,
+ proj_drop_rate=0.1,
+ detach_grad=detach_grad,
+ mlm=kwargs.get(
+ 'mlm', False))
+ self.merge_layer = FeatureMapEnhancer(feat_dim,
+ num_layers=num_mg_layers)
+ self.celoss_fn = nn.CrossEntropyLoss(reduction='mean',
+ ignore_index=self.ignore_index)
+ self.coef = coef # for loss of rec, eos and ent respectively
+ # self.coef=(1.0, 0.0, 0.0)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ trunc_normal_(m.weight, std=.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, data=None):
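+        # One pass of the neighbor decoder, optionally followed by `iters`
+        # refinement passes in which the contextual module and the feature-map
+        # enhancer update the visual features before decoding again.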
+ if data is not None:
+ labels, label_lens = data
+ label_lens = label_lens + 1
+ max_char = label_lens.max()
+ else:
+ max_char = self.max_len
+
+ res_vis = self.decoder(x, max_char=max_char)
+ res_list = [res_vis]
+ if self.use_fem:
+ for it in range(self.iters):
+ char_feat_cntx = self.cntx_module(res_list[-1]['char_feats'],
+ res_list[-1]['char_masks'])
+ # import ipdb;ipdb.set_trace()
+ char_maps = res_list[-1]['char_maps']
+ if self.detach_grad:
+ char_maps = char_maps.detach()
+ feat_map = self.merge_layer(
+ x,
+ char_feat_cntx,
+ char_maps[:, :, :-1],
+ )
+ res_i = self.decoder(feat_map, max_char)
+ res_list.append(res_i)
+ if self.training:
+ loss_dict = self.get_loss(res_list[0], labels, label_lens)
+ for it in range(self.iters):
+ loss_dict_i = self.get_loss(res_list[it + 1], labels,
+ label_lens)
+ for k, v in loss_dict_i.items():
+ loss_dict[k] += v
+ else:
+ loss_dict = None
+ return [loss_dict, res_list[-1]]
+
+ def calc_rec_loss(self, logits, targets):
+ """
+ Args:
+            logits: [minibatch, C, T], raw scores (not passed through softmax).
+            targets: torch.LongTensor [minibatch, T]
+ """
+ losses = self.celoss_fn(logits, targets)
+ return losses
+
+ def calc_eos_loc_loss(self, char_maps, target_lens, eps=1e-10):
+ max_tok = char_maps.shape[2]
+ eos_idx = (target_lens - 1).contiguous().view(-1, 1, 1).expand(
+ -1, 1, max_tok)
+ eos_maps = torch.gather(char_maps, dim=1,
+ index=eos_idx).squeeze(1) # (b, max_tok)
+ loss = (eos_maps[:, -1] + eps).log().neg()
+ return loss.mean()
+
+ def calc_entropy(self, p: torch.Tensor, mask: torch.Tensor, eps=1e-10):
+ """
+ Args:
+ p: probability distribution over the last dimension, of size (..., L, C)
+ mask: (..., L)
+ """
+ p_nlog = (p + eps).log().neg()
+ ent = p * p_nlog
+ ent = ent.sum(-1) / math.log(p.size(-1) + 1)
+ ent = (ent * mask).sum(-1) / (mask.sum(-1) + eps) # (...)
+ ent = ent.mean()
+ return ent
+
+ def get_loss(self, model_output, labels, label_lens):
+ labels = labels[:, :label_lens.max()]
+ batch_size, max_len = labels.size()
+ seq_range = torch.arange(
+ 0, max_len, device=labels.device).long().unsqueeze(0).expand(
+ batch_size, max_len)
+ seq_len = label_lens.unsqueeze(1).expand_as(seq_range)
+ mask = (seq_range < seq_len).float() # [batch_size, max_len]
+
+ l_rec = self.calc_rec_loss(model_output['logits'].transpose(1, 2),
+ labels)
+ l_eos = self.calc_eos_loc_loss(model_output['char_maps'], label_lens)
+ l_ent = self.calc_entropy(model_output['char_maps'], mask)
+
+ loss = l_rec * self.coef[0] + l_eos * self.coef[1] + l_ent * self.coef[
+ 2]
+ loss_dict = dict(
+ loss=loss,
+ l_rec=l_rec,
+ l_eos=l_eos,
+ l_ent=l_ent,
+ )
+ return loss_dict
diff --git a/openrec/modeling/decoders/lpv_decoder.py b/openrec/modeling/decoders/lpv_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc017eddf048fb76229e22929aa05b649a220ebc
--- /dev/null
+++ b/openrec/modeling/decoders/lpv_decoder.py
@@ -0,0 +1,119 @@
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .abinet_decoder import PositionAttention
+from .nrtr_decoder import PositionalEncoding, TransformerBlock
+
+
+class Trans(nn.Module):
+
+ def __init__(self, dim, nhead, dim_feedforward, dropout, num_layers):
+ super().__init__()
+ self.d_model = dim
+ self.nhead = nhead
+
+ self.pos_encoder = PositionalEncoding(dropout=0.0,
+ dim=self.d_model,
+ max_len=512)
+
+ self.transformer = nn.ModuleList([
+ TransformerBlock(
+ dim,
+ nhead,
+ dim_feedforward,
+ attention_dropout_rate=dropout,
+ residual_dropout_rate=dropout,
+ with_self_attn=True,
+ with_cross_attn=False,
+ ) for i in range(num_layers)
+ ])
+
+ def forward(self, feature, attn_map=None, use_mask=False):
+ n, c, h, w = feature.shape
+ feature = feature.flatten(2).transpose(1, 2)
+
+ if use_mask:
+ _, t, h, w = attn_map.shape
+ location_mask = (attn_map.view(n, t, -1).transpose(1, 2) >
+ 0.05).type(torch.float) # n,hw,t
+ location_mask = location_mask.bmm(location_mask.transpose(
+ 1, 2)) # n, hw, hw
+ location_mask = location_mask.new_zeros(
+ (h * w, h * w)).masked_fill(location_mask > 0, float('-inf'))
+ location_mask = location_mask.unsqueeze(1) # n, 1, hw, hw
+ else:
+ location_mask = None
+
+ feature = self.pos_encoder(feature)
+ for layer in self.transformer:
+ feature = layer(feature, self_mask=location_mask)
+ feature = feature.transpose(1, 2).view(n, c, h, w)
+ return feature, location_mask
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+class LPVDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_layer=3,
+ max_len=25,
+ use_mask=False,
+ dim_feedforward=1024,
+ nhead=8,
+ dropout=0.1,
+ trans_layer=2):
+ super().__init__()
+ self.use_mask = use_mask
+ self.max_len = max_len
+ attn_layer = PositionAttention(max_length=max_len + 1,
+ mode='nearest',
+ in_channels=in_channels,
+ num_channels=in_channels // 8)
+ trans_layer = Trans(dim=in_channels,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ dropout=dropout,
+ num_layers=trans_layer)
+ cls_layer = nn.Linear(in_channels, out_channels - 2)
+
+ self.attention = _get_clones(attn_layer, num_layer)
+ self.trans = _get_clones(trans_layer, num_layer - 1)
+ self.cls = _get_clones(cls_layer, num_layer)
+
+ def forward(self, x, data=None):
+ if data is not None:
+ max_len = data[1].max()
+ else:
+ max_len = self.max_len
+ features = x # (N, E, H, W)
+
+ attn_vecs, attn_scores_map = self.attention[0](features)
+ attn_vecs = attn_vecs[:, :max_len + 1, :]
+ if not self.training:
+ for i in range(1, len(self.attention)):
+ features, mask = self.trans[i - 1](features,
+ attn_scores_map,
+ use_mask=self.use_mask)
+ attn_vecs, attn_scores_map = self.attention[i](
+ features, attn_vecs) # (N, T, E), (N, T, H, W)
+ return F.softmax(self.cls[-1](attn_vecs), -1)
+ else:
+ logits = []
+ logit = self.cls[0](attn_vecs) # (N, T, C)
+ logits.append(logit)
+ for i in range(1, len(self.attention)):
+ features, mask = self.trans[i - 1](features,
+ attn_scores_map,
+ use_mask=self.use_mask)
+ attn_vecs, attn_scores_map = self.attention[i](
+ features, attn_vecs) # (N, T, E), (N, T, H, W)
+ logit = self.cls[i](attn_vecs) # (N, T, C)
+ logits.append(logit)
+ return logits
diff --git a/openrec/modeling/decoders/matrn_decoder.py b/openrec/modeling/decoders/matrn_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8e69e6f9323805635eb6cc4a3008e3c13403ddc
--- /dev/null
+++ b/openrec/modeling/decoders/matrn_decoder.py
@@ -0,0 +1,236 @@
+"""This code is refer from:
+https://github.com/byeonghu-na/MATRN
+"""
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from openrec.modeling.decoders.abinet_decoder import BCNLanguage, PositionAttention, _get_length
+from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock
+
+
+class BaseSemanticVisual_backbone_feature(nn.Module):
+
+ def __init__(self,
+ d_model=512,
+ nhead=8,
+ num_layers=4,
+ dim_feedforward=2048,
+ dropout=0.0,
+ alignment_mask_example_prob=0.9,
+ alignment_mask_candidate_prob=0.9,
+ alignment_num_vis_mask=10,
+ max_length=25,
+ num_classes=37):
+ super().__init__()
+ self.mask_example_prob = alignment_mask_example_prob
+        self.mask_candidate_prob = alignment_mask_candidate_prob
+ self.num_vis_mask = alignment_num_vis_mask
+ self.nhead = nhead
+
+ self.d_model = d_model
+ self.max_length = max_length + 1 # additional stop token
+
+ self.model1 = nn.ModuleList([
+ TransformerBlock(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ attention_dropout_rate=dropout,
+ residual_dropout_rate=dropout,
+ with_self_attn=True,
+ with_cross_attn=False,
+ ) for i in range(num_layers)
+ ])
+ self.pos_encoder_tfm = PositionalEncoding(dim=d_model,
+ dropout=0,
+ max_len=1024)
+
+ self.model2_vis = PositionAttention(
+ max_length=self.max_length, # additional stop token
+ in_channels=d_model,
+ num_channels=d_model // 8,
+ mode='nearest',
+ )
+ self.cls_vis = nn.Linear(d_model, num_classes)
+ self.cls_sem = nn.Linear(d_model, num_classes)
+ self.w_att = nn.Linear(2 * d_model, d_model)
+
+ v_token = torch.empty((1, d_model))
+ self.v_token = nn.Parameter(v_token)
+ torch.nn.init.uniform_(self.v_token, -0.001, 0.001)
+
+ self.cls = nn.Linear(d_model, num_classes)
+
+ def forward(self, l_feature, v_feature, lengths_l=None, v_attn=None):
+ """
+ Args:
+            l_feature: (N, T, E) where T is the length, N is the batch size and E is the model dim
+            v_feature: (N, E, H, W)
+            lengths_l: (N,)
+            v_attn: (N, T, H, W)
+ """
+
+ N, E, H, W = v_feature.size()
+ v_feature = v_feature.flatten(2, 3).transpose(1, 2) #(N, H*W, E)
+ v_attn = v_attn.flatten(2, 3) # (N, T, H*W)
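+        # During training, randomly pick one character per sample and replace
+        # the visual tokens it attends to most strongly with the learnable
+        # mask token.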
+ if self.training:
+ for idx, length in enumerate(lengths_l):
+ if np.random.random() <= self.mask_example_prob:
+ l_idx = np.random.randint(int(length))
+ v_random_idx = v_attn[idx, l_idx].argsort(
+ descending=True).cpu().numpy()[:self.num_vis_mask, ]
+ v_random_idx = v_random_idx[np.random.random(
+ v_random_idx.shape) <= self.mask_candidate_prob]
+ v_feature[idx, v_random_idx] = self.v_token
+
+ zeros = v_feature.new_zeros((N, H * W, E)) # (N, H*W, E)
+ base_pos = self.pos_encoder_tfm(zeros) # (N, H*W, E)
+ base_pos = torch.bmm(v_attn, base_pos) # (N, T, E)
+
+ l_feature = l_feature + base_pos
+
+        sv_feature = torch.cat((v_feature, l_feature), dim=1) # (N, H*W+T, E)
+        for decoder_layer in self.model1:
+            sv_feature = decoder_layer(sv_feature) # (N, H*W+T, E)
+
+ sv_to_v_feature = sv_feature[:, :H * W] # (N, H*W, E)
+ sv_to_s_feature = sv_feature[:, H * W:] # (N, T, E)
+
+ sv_to_v_feature = sv_to_v_feature.transpose(1, 2).reshape(N, E, H, W)
+ sv_to_v_feature, _ = self.model2_vis(sv_to_v_feature) # (N, T, E)
+ sv_to_v_logits = self.cls_vis(sv_to_v_feature) # (N, T, C)
+ pt_v_lengths = _get_length(sv_to_v_logits) # (N,)
+
+ sv_to_s_logits = self.cls_sem(sv_to_s_feature) # (N, T, C)
+ pt_s_lengths = _get_length(sv_to_s_logits) # (N,)
+
+ f = torch.cat((sv_to_v_feature, sv_to_s_feature), dim=2)
+ f_att = torch.sigmoid(self.w_att(f))
+ output = f_att * sv_to_v_feature + (1 - f_att) * sv_to_s_feature
+
+ logits = self.cls(output) # (N, T, C)
+ pt_lengths = _get_length(logits)
+
+ return {
+ 'logits': logits,
+ 'pt_lengths': pt_lengths,
+ 'v_logits': sv_to_v_logits,
+ 'pt_v_lengths': pt_v_lengths,
+ 's_logits': sv_to_s_logits,
+ 'pt_s_lengths': pt_s_lengths,
+ 'name': 'alignment'
+ }
+
+
+class MATRNDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ nhead=8,
+ num_layers=3,
+ dim_feedforward=2048,
+ dropout=0.1,
+ max_length=25,
+ iter_size=3,
+ **kwargs):
+ super().__init__()
+ self.max_length = max_length + 1
+ d_model = in_channels
+ self.pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
+ self.encoder = nn.ModuleList([
+ TransformerBlock(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ attention_dropout_rate=dropout,
+ residual_dropout_rate=dropout,
+ with_self_attn=True,
+ with_cross_attn=False,
+ ) for _ in range(num_layers)
+ ])
+ self.decoder = PositionAttention(
+ max_length=self.max_length, # additional stop token
+ in_channels=d_model,
+ num_channels=d_model // 8,
+ mode='nearest',
+ )
+ self.out_channels = out_channels
+ self.cls = nn.Linear(d_model, self.out_channels)
+ self.iter_size = iter_size
+ if iter_size > 0:
+ self.language = BCNLanguage(
+ d_model=d_model,
+ nhead=nhead,
+ num_layers=4,
+ dim_feedforward=dim_feedforward,
+ dropout=dropout,
+ max_length=max_length,
+ num_classes=self.out_channels,
+ )
+ # alignment
+ self.semantic_visual = BaseSemanticVisual_backbone_feature(
+ d_model=d_model,
+ nhead=nhead,
+ num_layers=2,
+ dim_feedforward=dim_feedforward,
+ max_length=max_length,
+ num_classes=self.out_channels)
+
+ def forward(self, x, data=None):
+ # bs, c, h, w
+ x = x.permute([0, 2, 3, 1]) # bs, h, w, c
+ _, H, W, C = x.shape
+ # assert H % 8 == 0 and W % 16 == 0, 'The height and width should be multiples of 8 and 16.'
+ feature = x.flatten(1, 2) # bs, h*w, c
+ feature = self.pos_encoder(feature) # bs, h*w, c
+ for encoder_layer in self.encoder:
+ feature = encoder_layer(feature)
+ # bs, h*w, c
+ feature = feature.reshape([-1, H, W, C]).permute(0, 3, 1,
+ 2) # bs, c, h, w
+ v_feature, v_attn_input = self.decoder(feature) # (bs[N], T, E)
+        vis_logits = self.cls(v_feature)  # (bs[N], T, C)
+ align_lengths = _get_length(vis_logits)
+ align_logits = vis_logits
+ all_l_res, all_a_res = [], []
+ for _ in range(self.iter_size):
+ tokens = F.softmax(align_logits, dim=-1)
+ lengths = torch.clamp(
+ align_lengths, 2,
+ self.max_length) # TODO: move to language model
+ l_feature, l_logits = self.language(tokens, lengths)
+ all_l_res.append(l_logits)
+ # alignment
+ lengths_l = _get_length(l_logits)
+ lengths_l.clamp_(2, self.max_length)
+
+ a_res = self.semantic_visual(l_feature,
+ feature,
+ lengths_l=lengths_l,
+ v_attn=v_attn_input)
+
+ a_v_res = a_res['v_logits']
+ # {'logits': a_res['v_logits'], 'pt_lengths': a_res['pt_v_lengths'], 'loss_weight': a_res['loss_weight'],
+ # 'name': 'alignment'}
+ all_a_res.append(a_v_res)
+ a_s_res = a_res['s_logits']
+ # {'logits': a_res['s_logits'], 'pt_lengths': a_res['pt_s_lengths'], 'loss_weight': a_res['loss_weight'],
+ # 'name': 'alignment'}
+ align_logits = a_res['logits']
+ all_a_res.append(a_s_res)
+ all_a_res.append(align_logits)
+ align_lengths = a_res['pt_lengths']
+ if self.training:
+ return {
+ 'align': all_a_res,
+ 'lang': all_l_res,
+ 'vision': vis_logits
+ }
+ else:
+ return F.softmax(align_logits, -1)
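+
+
+# A minimal standalone sketch (not used by the decoder above) of the gated
+# visual-semantic fusion performed in BaseSemanticVisual_backbone_feature:
+# the two aligned streams are concatenated, a sigmoid gate is learned, and
+# the output is their convex combination. Shapes follow the (N, T, E)
+# convention used above; the sizes below are illustrative only.
+if __name__ == '__main__':
+    N, T, E = 2, 26, 512
+    sv_to_v_feature = torch.randn(N, T, E)  # visually aligned features
+    sv_to_s_feature = torch.randn(N, T, E)  # semantically aligned features
+    w_att = nn.Linear(2 * E, E)
+    f_att = torch.sigmoid(
+        w_att(torch.cat((sv_to_v_feature, sv_to_s_feature), dim=2)))
+    fused = f_att * sv_to_v_feature + (1 - f_att) * sv_to_s_feature
+    print(fused.shape)  # torch.Size([2, 26, 512])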
diff --git a/openrec/modeling/decoders/mgp_decoder.py b/openrec/modeling/decoders/mgp_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9822b064908b1acdf80a41e98336378a1008f1bb
--- /dev/null
+++ b/openrec/modeling/decoders/mgp_decoder.py
@@ -0,0 +1,99 @@
+'''
+This code is adapted from:
+https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/MGP-STR
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TokenLearner(nn.Module):
+
+ def __init__(self, input_embed_dim, out_token=30):
+ super().__init__()
+ self.token_norm = nn.LayerNorm(input_embed_dim)
+ self.tokenLearner = nn.Sequential(
+ nn.Conv2d(input_embed_dim,
+ input_embed_dim,
+ kernel_size=(1, 1),
+ stride=1,
+ groups=8,
+ bias=False),
+ nn.Conv2d(input_embed_dim,
+ out_token,
+ kernel_size=(1, 1),
+ stride=1,
+ bias=False))
+ self.feat = nn.Conv2d(input_embed_dim,
+ input_embed_dim,
+ kernel_size=(1, 1),
+ stride=1,
+ groups=8,
+ bias=False)
+ self.norm = nn.LayerNorm(input_embed_dim)
+
+ def forward(self, x):
+ x = self.token_norm(x) # [bs, 257, 768]
+ x = x.transpose(1, 2).unsqueeze(-1) # [bs, 768, 257, 1]
+ selected = self.tokenLearner(x) # [bs, 27, 257, 1].
+ selected = selected.flatten(2) # [bs, 27, 257].
+ selected = F.softmax(selected, dim=-1)
+ feat = self.feat(x) # [bs, 768, 257, 1].
+ feat = feat.flatten(2).transpose(1, 2) # [bs, 257, 768]
+ x = torch.einsum('...si,...id->...sd', selected, feat) # [bs, 27, 768]
+
+ x = self.norm(x)
+ return selected, x
+
+
+class MGPDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ max_len=25,
+ only_char=False,
+ *args,
+ **kwargs):
+ super().__init__(*args, **kwargs)
+ num_classes = out_channels
+ embed_dim = in_channels
+ self.batch_max_length = max_len + 2
+ self.char_tokenLearner = TokenLearner(embed_dim, self.batch_max_length)
+ self.char_head = nn.Linear(
+ embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ self.only_char = only_char
+ if not only_char:
+ self.bpe_tokenLearner = TokenLearner(embed_dim,
+ self.batch_max_length)
+ self.wp_tokenLearner = TokenLearner(embed_dim,
+ self.batch_max_length)
+ self.bpe_head = nn.Linear(
+ embed_dim, 50257) if num_classes > 0 else nn.Identity()
+ self.wp_head = nn.Linear(
+ embed_dim, 30522) if num_classes > 0 else nn.Identity()
+
+ def forward(self, x, data=None):
+ # attens = []
+ # char
+ char_attn, x_char = self.char_tokenLearner(x)
+ x_char = self.char_head(x_char)
+ char_out = x_char
+ # attens = [char_attn]
+ if not self.only_char:
+ # bpe
+ bpe_attn, x_bpe = self.bpe_tokenLearner(x)
+ bpe_out = self.bpe_head(x_bpe)
+ # attens += [bpe_attn]
+ # wp
+ wp_attn, x_wp = self.wp_tokenLearner(x)
+ wp_out = self.wp_head(x_wp)
+ return [char_out, bpe_out, wp_out] if self.training else [
+ F.softmax(char_out, -1),
+ F.softmax(bpe_out, -1),
+ F.softmax(wp_out, -1)
+ ]
+ # attens += [wp_attn]
+
+ return char_out if self.training else F.softmax(char_out, -1)
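+
+
+# Minimal CPU shape check for the multi-granularity heads above (a sketch;
+# the 257-token / 768-dim sizes are illustrative ViT-style values, not a
+# requirement of the module).
+if __name__ == '__main__':
+    decoder = MGPDecoder(in_channels=768, out_channels=38, max_len=25)
+    decoder.eval()
+    feats = torch.randn(2, 257, 768)  # (batch, tokens, embed_dim)
+    char_out, bpe_out, wp_out = decoder(feats)
+    # each head predicts batch_max_length = max_len + 2 = 27 query tokens
+    print(char_out.shape, bpe_out.shape, wp_out.shape)
+    # torch.Size([2, 27, 38]) torch.Size([2, 27, 50257]) torch.Size([2, 27, 30522])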
diff --git a/openrec/modeling/decoders/nrtr_decoder.py b/openrec/modeling/decoders/nrtr_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c7e4ddfd8cff53516f0d984f8c23be1e0da6ced
--- /dev/null
+++ b/openrec/modeling/decoders/nrtr_decoder.py
@@ -0,0 +1,439 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from openrec.modeling.common import Mlp
+
+
+class NRTRDecoder(nn.Module):
+ """A transformer model. User is able to modify the attributes as needed.
+ The architechture is based on the paper "Attention Is All You Need". Ashish
+ Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N
+ Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you
+ need. In Advances in Neural Information Processing Systems, pages
+ 6000-6010.
+
+ Args:
+ d_model: the number of expected features in the encoder/decoder inputs (default=512).
+ nhead: the number of heads in the multiheadattention models (default=8).
+ num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
+ num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+ custom_encoder: custom encoder (default=None).
+ custom_decoder: custom decoder (default=None).
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ nhead=None,
+ num_encoder_layers=6,
+ beam_size=0,
+ num_decoder_layers=6,
+ max_len=25,
+ attention_dropout_rate=0.0,
+ residual_dropout_rate=0.1,
+ scale_embedding=True,
+ ):
+ super(NRTRDecoder, self).__init__()
+ self.out_channels = out_channels
+ self.ignore_index = out_channels - 1
+ self.bos = out_channels - 2
+ self.eos = 0
+ self.max_len = max_len
+ d_model = in_channels
+ dim_feedforward = d_model * 4
+ nhead = nhead if nhead is not None else d_model // 32
+ self.embedding = Embeddings(
+ d_model=d_model,
+ vocab=self.out_channels,
+ padding_idx=0,
+ scale_embedding=scale_embedding,
+ )
+ self.positional_encoding = PositionalEncoding(
+ dropout=residual_dropout_rate, dim=d_model)
+
+ if num_encoder_layers > 0:
+ self.encoder = nn.ModuleList([
+ TransformerBlock(
+ d_model,
+ nhead,
+ dim_feedforward,
+ attention_dropout_rate,
+ residual_dropout_rate,
+ with_self_attn=True,
+ with_cross_attn=False,
+ ) for i in range(num_encoder_layers)
+ ])
+ else:
+ self.encoder = None
+
+ self.decoder = nn.ModuleList([
+ TransformerBlock(
+ d_model,
+ nhead,
+ dim_feedforward,
+ attention_dropout_rate,
+ residual_dropout_rate,
+ with_self_attn=True,
+ with_cross_attn=True,
+ ) for i in range(num_decoder_layers)
+ ])
+
+ self.beam_size = beam_size
+ self.d_model = d_model
+ self.nhead = nhead
+ self.tgt_word_prj = nn.Linear(d_model,
+ self.out_channels - 2,
+ bias=False)
+ w0 = np.random.normal(0.0, d_model**-0.5,
+ (d_model, self.out_channels - 2)).astype(
+ np.float32)
+ self.tgt_word_prj.weight.data = torch.from_numpy(w0.transpose())
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward_train(self, src, tgt):
+ tgt = tgt[:, :-1]
+
+ tgt = self.embedding(tgt)
+ tgt = self.positional_encoding(tgt)
+ tgt_mask = self.generate_square_subsequent_mask(
+ tgt.shape[1], device=src.get_device())
+
+ if self.encoder is not None:
+ src = self.positional_encoding(src)
+ for encoder_layer in self.encoder:
+ src = encoder_layer(src)
+ memory = src # B N C
+ else:
+ memory = src # B N C
+ for decoder_layer in self.decoder:
+ tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
+ output = tgt
+ logit = self.tgt_word_prj(output)
+ return logit
+
+ def forward(self, src, data=None):
+ """Take in and process masked source/target sequences.
+ Args:
+ src: the sequence to the encoder (required).
+ tgt: the sequence to the decoder (required).
+ Shape:
+ - src: :math:`(B, sN, C)`.
+ - tgt: :math:`(B, tN, C)`.
+ Examples:
+ >>> output = transformer_model(src, tgt)
+ """
+
+ if self.training:
+ max_len = data[1].max()
+ tgt = data[0][:, :2 + max_len]
+ res = self.forward_train(src, tgt)
+ else:
+ res = self.forward_test(src)
+ return res
+
+ def forward_test(self, src):
+ bs = src.shape[0]
+ if self.encoder is not None:
+ src = self.positional_encoding(src)
+ for encoder_layer in self.encoder:
+ src = encoder_layer(src)
+ memory = src # B N C
+ else:
+ memory = src
+ dec_seq = torch.full((bs, self.max_len + 1),
+ self.ignore_index,
+ dtype=torch.int64,
+ device=src.get_device())
+ dec_seq[:, 0] = self.bos
+ logits = []
+ self.attn_maps = []
+ for len_dec_seq in range(0, self.max_len):
+ dec_seq_embed = self.embedding(
+ dec_seq[:, :len_dec_seq + 1]) # N dim 26+10 # 012 a
+ dec_seq_embed = self.positional_encoding(dec_seq_embed)
+ tgt_mask = self.generate_square_subsequent_mask(
+ dec_seq_embed.shape[1], src.get_device())
+ tgt = dec_seq_embed # bs, 3, dim #bos, a, b, c, ... eos
+ for decoder_layer in self.decoder:
+ tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
+ self.attn_maps.append(
+ self.decoder[-1].cross_attn.attn_map[0][:, -1:, :])
+ dec_output = tgt
+ dec_output = dec_output[:, -1:, :]
+
+ word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
+ logits.append(word_prob)
+ if len_dec_seq < self.max_len:
+ # greedy decode. add the next token index to the target input
+ dec_seq[:, len_dec_seq + 1] = word_prob.squeeze().argmax(-1)
+ # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
+ if (dec_seq == self.eos).any(dim=-1).all():
+ break
+ logits = torch.cat(logits, dim=1)
+ return logits
+
+ def generate_square_subsequent_mask(self, sz, device):
+ """Generate a square mask for the sequence.
+
+ The masked positions are filled with float('-inf'). Unmasked positions
+ are filled with float(0.0).
+ """
+ mask = torch.zeros([sz, sz], dtype=torch.float32)
+ mask_inf = torch.triu(
+ torch.full((sz, sz), dtype=torch.float32, fill_value=-torch.inf),
+ diagonal=1,
+ )
+ mask = mask + mask_inf
+ return mask.unsqueeze(0).unsqueeze(0).to(device)
+
+
+class MultiheadAttention(nn.Module):
+
+ def __init__(self, embed_dim, num_heads, dropout=0.0, self_attn=False):
+ super(MultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.head_dim = embed_dim // num_heads
+ assert (self.head_dim * num_heads == self.embed_dim
+ ), 'embed_dim must be divisible by num_heads'
+ self.scale = self.head_dim**-0.5
+ self.self_attn = self_attn
+ if self_attn:
+ self.qkv = nn.Linear(embed_dim, embed_dim * 3)
+ else:
+ self.q = nn.Linear(embed_dim, embed_dim)
+ self.kv = nn.Linear(embed_dim, embed_dim * 2)
+ self.attn_drop = nn.Dropout(dropout)
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
+
+ def forward(self, query, key=None, attn_mask=None):
+ B, qN = query.shape[:2]
+
+ if self.self_attn:
+ qkv = self.qkv(query)
+ qkv = qkv.reshape(B, qN, 3, self.num_heads,
+ self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0)
+ else:
+ kN = key.shape[1]
+ q = self.q(query)
+ q = q.reshape(B, qN, self.num_heads, self.head_dim).transpose(1, 2)
+ kv = self.kv(key)
+ kv = kv.reshape(B, kN, 2, self.num_heads,
+ self.head_dim).permute(2, 0, 3, 1, 4)
+ k, v = kv.unbind(0)
+
+ attn = (q.matmul(k.transpose(2, 3))) * self.scale
+ if attn_mask is not None:
+ attn += attn_mask
+
+ attn = F.softmax(attn, dim=-1)
+ if not self.training:
+ self.attn_map = attn
+ attn = self.attn_drop(attn)
+
+ x = (attn.matmul(v)).transpose(1, 2)
+ x = x.reshape(B, qN, self.embed_dim)
+ x = self.out_proj(x)
+
+ return x
+
+
+class TransformerBlock(nn.Module):
+
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ attention_dropout_rate=0.0,
+ residual_dropout_rate=0.1,
+ with_self_attn=True,
+ with_cross_attn=False,
+ epsilon=1e-5,
+ ):
+ super(TransformerBlock, self).__init__()
+ self.with_self_attn = with_self_attn
+ if with_self_attn:
+ self.self_attn = MultiheadAttention(d_model,
+ nhead,
+ dropout=attention_dropout_rate,
+ self_attn=with_self_attn)
+ self.norm1 = nn.LayerNorm(d_model, eps=epsilon)
+ self.dropout1 = nn.Dropout(residual_dropout_rate)
+ self.with_cross_attn = with_cross_attn
+ if with_cross_attn:
+ self.cross_attn = MultiheadAttention(
+ d_model, nhead, dropout=attention_dropout_rate
+ ) # for self_attn of encoder or cross_attn of decoder
+ self.norm2 = nn.LayerNorm(d_model, eps=epsilon)
+ self.dropout2 = nn.Dropout(residual_dropout_rate)
+
+ self.mlp = Mlp(
+ in_features=d_model,
+ hidden_features=dim_feedforward,
+ act_layer=nn.ReLU,
+ drop=residual_dropout_rate,
+ )
+
+ self.norm3 = nn.LayerNorm(d_model, eps=epsilon)
+
+ self.dropout3 = nn.Dropout(residual_dropout_rate)
+
+ def forward(self, tgt, memory=None, self_mask=None, cross_mask=None):
+ if self.with_self_attn:
+ tgt1 = self.self_attn(tgt, attn_mask=self_mask)
+ tgt = self.norm1(tgt + self.dropout1(tgt1))
+
+ if self.with_cross_attn:
+ tgt2 = self.cross_attn(tgt, key=memory, attn_mask=cross_mask)
+ tgt = self.norm2(tgt + self.dropout2(tgt2))
+ tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))
+ return tgt
+
+
+class PositionalEncoding(nn.Module):
+ """Inject some information about the relative or absolute position of the
+ tokens in the sequence. The positional encodings have the same dimension as
+ the embeddings, so that the two can be summed. Here, we use sine and cosine
+ functions of different frequencies.
+
+ .. math::
+ \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
+ \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
+ \text{where pos is the word position and i is the embed idx)
+ Args:
+ d_model: the embed dim (required).
+ dropout: the dropout value (default=0.1).
+ max_len: the max. length of the incoming sequence (default=5000).
+ Examples:
+ >>> pos_encoder = PositionalEncoding(d_model)
+ """
+
+ def __init__(self, dropout, dim, max_len=5000):
+ super(PositionalEncoding, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ pe = torch.zeros([max_len, dim])
+ position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = torch.unsqueeze(pe, 0)
+ # pe = torch.permute(pe, [1, 0, 2])
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ """Inputs of forward function
+ Args:
+ x: the sequence fed to the positional encoder model (required).
+ Shape:
+            x: [batch size, sequence length, embed dim]
+            output: [batch size, sequence length, embed dim]
+ Examples:
+ >>> output = pos_encoder(x)
+ """
+ # x = x.permute([1, 0, 2])
+ # x = x + self.pe[:x.shape[0], :]
+ x = x + self.pe[:, :x.shape[1], :]
+ return self.dropout(x) # .permute([1, 0, 2])
+
+
+class PositionalEncoding_2d(nn.Module):
+ """Inject some information about the relative or absolute position of the
+ tokens in the sequence. The positional encodings have the same dimension as
+ the embeddings, so that the two can be summed. Here, we use sine and cosine
+ functions of different frequencies.
+
+ .. math::
+ \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
+ \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
+ \text{where pos is the word position and i is the embed idx)
+ Args:
+ d_model: the embed dim (required).
+ dropout: the dropout value (default=0.1).
+ max_len: the max. length of the incoming sequence (default=5000).
+ Examples:
+ >>> pos_encoder = PositionalEncoding(d_model)
+ """
+
+ def __init__(self, dropout, dim, max_len=5000):
+ super(PositionalEncoding_2d, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ pe = torch.zeros([max_len, dim])
+ position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = torch.permute(torch.unsqueeze(pe, 0), [1, 0, 2])
+ self.register_buffer('pe', pe)
+
+ self.avg_pool_1 = nn.AdaptiveAvgPool2d((1, 1))
+ self.linear1 = nn.Linear(dim, dim)
+ self.linear1.weight.data.fill_(1.0)
+ self.avg_pool_2 = nn.AdaptiveAvgPool2d((1, 1))
+ self.linear2 = nn.Linear(dim, dim)
+ self.linear2.weight.data.fill_(1.0)
+
+ def forward(self, x):
+ """Inputs of forward function
+ Args:
+ x: the sequence fed to the positional encoder model (required).
+ Shape:
+                x: [batch size, embed dim, height, width]
+                output: [height * width, batch size, embed dim]
+ Examples:
+ >>> output = pos_encoder(x)
+ """
+ w_pe = self.pe[:x.shape[-1], :]
+ w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
+ w_pe = w_pe * w1
+ w_pe = torch.permute(w_pe, [1, 2, 0])
+ w_pe = torch.unsqueeze(w_pe, 2)
+
+ h_pe = self.pe[:x.shape[-2], :]
+ w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
+ h_pe = h_pe * w2
+ h_pe = torch.permute(h_pe, [1, 2, 0])
+ h_pe = torch.unsqueeze(h_pe, 3)
+
+ x = x + w_pe + h_pe
+ x = torch.permute(
+ torch.reshape(x,
+ [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
+ [2, 0, 1],
+ )
+
+ return self.dropout(x)
+
+
+class Embeddings(nn.Module):
+
+ def __init__(self, d_model, vocab, padding_idx=None, scale_embedding=True):
+ super(Embeddings, self).__init__()
+ self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
+ self.embedding.weight.data.normal_(mean=0.0, std=d_model**-0.5)
+ self.d_model = d_model
+ self.scale_embedding = scale_embedding
+
+ def forward(self, x):
+ if self.scale_embedding:
+ x = self.embedding(x)
+ return x * math.sqrt(self.d_model)
+ return self.embedding(x)
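+
+
+# Minimal CPU sketch of the building blocks above: sinusoidal positional
+# encoding followed by a self-attention encoder block, then a decoder-style
+# block with cross-attention over an encoder memory. Shapes follow the
+# batch-first (B, N, C) convention used throughout this file; sizes are
+# illustrative only.
+if __name__ == '__main__':
+    pos_enc = PositionalEncoding(dropout=0.1, dim=64)
+    enc_block = TransformerBlock(d_model=64, nhead=8, dim_feedforward=256)
+    x = torch.randn(2, 32, 64)  # (B, N, C)
+    y = enc_block(pos_enc(x))  # self-attention only, no memory
+    print(y.shape)  # torch.Size([2, 32, 64])
+
+    dec_block = TransformerBlock(d_model=64,
+                                 nhead=8,
+                                 dim_feedforward=256,
+                                 with_cross_attn=True)
+    memory = torch.randn(2, 100, 64)
+    z = dec_block(y, memory=memory)
+    print(z.shape)  # torch.Size([2, 32, 64])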
diff --git a/openrec/modeling/decoders/ote_decoder.py b/openrec/modeling/decoders/ote_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..5db4c4c10b4fd126d84b54177015ae6489f93bae
--- /dev/null
+++ b/openrec/modeling/decoders/ote_decoder.py
@@ -0,0 +1,205 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.init import ones_, trunc_normal_, zeros_
+
+from .nrtr_decoder import TransformerBlock, Embeddings
+
+
+class CPA(nn.Module):
+
+ def __init__(self, dim, max_len=25):
+ super(CPA, self).__init__()
+
+ self.fc1 = nn.Linear(dim, dim)
+ self.fc2 = nn.Linear(dim, dim)
+ self.fc3 = nn.Linear(dim, dim)
+ self.pos_embed = nn.Parameter(torch.zeros([1, max_len + 1, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+ trunc_normal_(self.pos_embed, std=0.02)
+
+ def forward(self, feat):
+ # feat: B, L, Dim
+ feat = feat.mean(1).unsqueeze(1) # B, 1, Dim
+ x = self.fc1(feat) + self.pos_embed # B max_len dim
+ x = F.softmax(self.fc2(F.tanh(x)), -1) # B max_len dim
+ x = self.fc3(feat * x) # B max_len dim
+ return x
+
+
+class ARDecoder(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ nhead=None,
+ num_decoder_layers=6,
+ max_len=25,
+ attention_dropout_rate=0.0,
+ residual_dropout_rate=0.1,
+ scale_embedding=True,
+ ):
+ super(ARDecoder, self).__init__()
+ self.out_channels = out_channels
+ self.ignore_index = out_channels - 1
+ self.bos = out_channels - 2
+ self.eos = 0
+ self.max_len = max_len
+ d_model = in_channels
+ dim_feedforward = d_model * 4
+ nhead = nhead if nhead is not None else d_model // 32
+ self.embedding = Embeddings(
+ d_model=d_model,
+ vocab=self.out_channels,
+ padding_idx=0,
+ scale_embedding=scale_embedding,
+ )
+ self.pos_embed = nn.Parameter(torch.zeros([1, max_len + 1, d_model],
+ dtype=torch.float32),
+ requires_grad=True)
+ trunc_normal_(self.pos_embed, std=0.02)
+ self.decoder = nn.ModuleList([
+ TransformerBlock(
+ d_model,
+ nhead,
+ dim_feedforward,
+ attention_dropout_rate,
+ residual_dropout_rate,
+ with_self_attn=True,
+ with_cross_attn=False,
+ ) for i in range(num_decoder_layers)
+ ])
+
+ self.tgt_word_prj = nn.Linear(d_model,
+ self.out_channels - 2,
+ bias=False)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward_train(self, src, tgt):
+ tgt = tgt[:, :-1]
+
+ tgt = self.embedding(
+ tgt) + src[:, :tgt.shape[1]] + self.pos_embed[:, :tgt.shape[1]]
+ tgt_mask = self.generate_square_subsequent_mask(
+ tgt.shape[1], device=src.get_device())
+
+ for decoder_layer in self.decoder:
+ tgt = decoder_layer(tgt, self_mask=tgt_mask)
+ output = tgt
+ logit = self.tgt_word_prj(output)
+ return logit
+
+ def forward(self, src, data=None):
+
+ if self.training:
+ max_len = data[1].max()
+ tgt = data[0][:, :2 + max_len]
+ res = self.forward_train(src, tgt)
+ else:
+ res = self.forward_test(src)
+ return res
+
+ def forward_test(self, src):
+ bs = src.shape[0]
+ src = src + self.pos_embed
+ dec_seq = torch.full((bs, self.max_len + 1),
+ self.ignore_index,
+ dtype=torch.int64,
+ device=src.get_device())
+ dec_seq[:, 0] = self.bos
+ logits = []
+ for len_dec_seq in range(0, self.max_len):
+ dec_seq_embed = self.embedding(
+ dec_seq[:, :len_dec_seq + 1]) # N dim 26+10 # 012 a
+ dec_seq_embed = dec_seq_embed + src[:, :len_dec_seq + 1]
+ tgt_mask = self.generate_square_subsequent_mask(
+ dec_seq_embed.shape[1], src.get_device())
+ tgt = dec_seq_embed # bs, 3, dim #bos, a, b, c, ... eos
+ for decoder_layer in self.decoder:
+ tgt = decoder_layer(tgt, self_mask=tgt_mask)
+ dec_output = tgt
+ dec_output = dec_output[:, -1:, :]
+ word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
+ logits.append(word_prob)
+ if len_dec_seq < self.max_len:
+ # greedy decode. add the next token index to the target input
+ dec_seq[:, len_dec_seq + 1] = word_prob.squeeze(1).argmax(-1)
+ # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
+ if (dec_seq == self.eos).any(dim=-1).all():
+ break
+ logits = torch.cat(logits, dim=1)
+ return logits
+
+ def generate_square_subsequent_mask(self, sz, device):
+ """Generate a square mask for the sequence.
+
+ The masked positions are filled with float('-inf'). Unmasked positions
+ are filled with float(0.0).
+ """
+ mask = torch.zeros([sz, sz], dtype=torch.float32)
+ mask_inf = torch.triu(
+ torch.full((sz, sz), dtype=torch.float32, fill_value=-torch.inf),
+ diagonal=1,
+ )
+ mask = mask + mask_inf
+ return mask.unsqueeze(0).unsqueeze(0).to(device)
+
+
+class OTEDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ max_len=25,
+ num_heads=None,
+ ar=False,
+ num_decoder_layers=1,
+ **kwargs):
+ super(OTEDecoder, self).__init__()
+
+ self.out_channels = out_channels - 2 # none + 26 + 10
+ dim = in_channels
+ self.dim = dim
+ self.max_len = max_len + 1 # max_len + eos
+
+ self.cpa = CPA(dim=dim, max_len=max_len)
+ self.ar = ar
+ if ar:
+ self.ar_decoder = ARDecoder(in_channels=dim,
+ out_channels=out_channels,
+ nhead=num_heads,
+ num_decoder_layers=num_decoder_layers,
+ max_len=max_len)
+ else:
+ self.fc = nn.Linear(dim, self.out_channels)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed'}
+
+ def forward(self, x, data=None):
+ x = self.cpa(x)
+ if self.ar:
+ return self.ar_decoder(x, data=data)
+ logits = self.fc(x) # B, 26, 37
+ if self.training:
+ logits = logits[:, :data[1].max() + 1]
+ return logits
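+
+
+# Minimal CPU sketch of the non-autoregressive path above: CPA pools the
+# encoder tokens into one vector, re-expands it to max_len + 1 positional
+# queries, and the linear head classifies them in a single pass. Sizes are
+# illustrative.
+if __name__ == '__main__':
+    decoder = OTEDecoder(in_channels=256, out_channels=38, max_len=25, ar=False)
+    decoder.eval()
+    feats = torch.randn(2, 64, 256)  # (B, L, dim) encoder tokens
+    logits = decoder(feats)
+    print(logits.shape)  # torch.Size([2, 26, 36]) -> (B, max_len + 1, out_channels - 2)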
diff --git a/openrec/modeling/decoders/parseq_decoder.py b/openrec/modeling/decoders/parseq_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..64d89769a7597db02837b6e860ed557fd1a0c09b
--- /dev/null
+++ b/openrec/modeling/decoders/parseq_decoder.py
@@ -0,0 +1,504 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from itertools import permutations
+from typing import Any, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn.modules import transformer
+
+
+class DecoderLayer(nn.Module):
+ """A Transformer decoder layer supporting two-stream attention (XLNet) This
+ implements a pre-LN decoder, as opposed to the post-LN default in
+ PyTorch."""
+
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation='gelu',
+ layer_norm_eps=1e-5,
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model,
+ nhead,
+ dropout=dropout,
+ batch_first=True)
+ self.cross_attn = nn.MultiheadAttention(d_model,
+ nhead,
+ dropout=dropout,
+ batch_first=True)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = transformer._get_activation_fn(activation)
+
+ def __setstate__(self, state):
+ if 'activation' not in state:
+ state['activation'] = F.gelu
+ super().__setstate__(state)
+
+ def forward_stream(
+ self,
+ tgt: Tensor,
+ tgt_norm: Tensor,
+ tgt_kv: Tensor,
+ memory: Tensor,
+ tgt_mask: Optional[Tensor],
+ tgt_key_padding_mask: Optional[Tensor],
+ ):
+ """Forward pass for a single stream (i.e. content or query) tgt_norm is
+ just a LayerNorm'd tgt.
+
+ Added as a separate parameter for efficiency. Both tgt_kv and memory
+ are expected to be LayerNorm'd too. memory is LayerNorm'd by ViT.
+ """
+ tgt2, sa_weights = self.self_attn(
+ tgt_norm,
+ tgt_kv,
+ tgt_kv,
+ attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)
+ tgt = tgt + self.dropout1(tgt2)
+
+ tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
+ self.attn_map = ca_weights
+ tgt = tgt + self.dropout2(tgt2)
+
+ tgt2 = self.linear2(
+ self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt, sa_weights, ca_weights
+
+ def forward(
+ self,
+ query,
+ content,
+ memory,
+ query_mask: Optional[Tensor] = None,
+ content_mask: Optional[Tensor] = None,
+ content_key_padding_mask: Optional[Tensor] = None,
+ update_content: bool = True,
+ ):
+ query_norm = self.norm_q(query)
+ content_norm = self.norm_c(content)
+ query = self.forward_stream(query, query_norm, content_norm, memory,
+ query_mask, content_key_padding_mask)[0]
+ if update_content:
+ content = self.forward_stream(content, content_norm, content_norm,
+ memory, content_mask,
+ content_key_padding_mask)[0]
+ return query, content
+
+
+class Decoder(nn.Module):
+ __constants__ = ['norm']
+
+ def __init__(self, decoder_layer, num_layers, norm):
+ super().__init__()
+ self.layers = transformer._get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(
+ self,
+ query,
+ content,
+ memory,
+ query_mask: Optional[Tensor] = None,
+ content_mask: Optional[Tensor] = None,
+ content_key_padding_mask: Optional[Tensor] = None,
+ ):
+ for i, mod in enumerate(self.layers):
+ last = i == len(self.layers) - 1
+ query, content = mod(
+ query,
+ content,
+ memory,
+ query_mask,
+ content_mask,
+ content_key_padding_mask,
+ update_content=not last,
+ )
+ query = self.norm(query)
+ return query
+
+
+class TokenEmbedding(nn.Module):
+
+ def __init__(self, charset_size: int, embed_dim: int):
+ super().__init__()
+ self.embedding = nn.Embedding(charset_size, embed_dim)
+ self.embed_dim = embed_dim
+
+ def forward(self, tokens: torch.Tensor):
+ return math.sqrt(self.embed_dim) * self.embedding(tokens)
+
+
+class PARSeqDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ max_label_length=25,
+ embed_dim=384,
+ dec_num_heads=12,
+ dec_mlp_ratio=4,
+ dec_depth=1,
+ perm_num=6,
+ perm_forward=True,
+ perm_mirrored=True,
+ decode_ar=True,
+ refine_iters=1,
+ dropout=0.1,
+ **kwargs: Any) -> None:
+ super().__init__()
+ self.pad_id = out_channels - 1
+ self.eos_id = 0
+ self.bos_id = out_channels - 2
+ self.max_label_length = max_label_length
+ self.decode_ar = decode_ar
+ self.refine_iters = refine_iters
+
+ decoder_layer = DecoderLayer(embed_dim, dec_num_heads,
+ embed_dim * dec_mlp_ratio, dropout)
+ self.decoder = Decoder(decoder_layer,
+ num_layers=dec_depth,
+ norm=nn.LayerNorm(embed_dim))
+
+ # Perm/attn mask stuff
+ self.rng = np.random.default_rng()
+ self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
+ self.perm_forward = perm_forward
+ self.perm_mirrored = perm_mirrored
+
+        # We don't predict [BOS] nor [PAD] (hence out_channels - 2)
+ self.head = nn.Linear(embed_dim, out_channels - 2)
+ self.text_embed = TokenEmbedding(out_channels, embed_dim)
+
+        # +1 for the [EOS] position at the end of the sequence
+ self.pos_queries = nn.Parameter(
+ torch.Tensor(1, max_label_length + 1, embed_dim))
+ self.dropout = nn.Dropout(p=dropout)
+ # Encoder has its own init.
+ self.apply(self._init_weights)
+ nn.init.trunc_normal_(self.pos_queries, std=0.02)
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize the weights using the typical initialization schemes used
+ in SOTA models."""
+
+ if isinstance(module, nn.Linear):
+ nn.init.trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Embedding):
+ nn.init.trunc_normal_(module.weight, std=0.02)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.Conv2d):
+ nn.init.kaiming_normal_(module.weight,
+ mode='fan_out',
+ nonlinearity='relu')
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.ones_(module.weight)
+ nn.init.zeros_(module.bias)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ param_names = {'text_embed.embedding.weight', 'pos_queries'}
+ return param_names
+
+ def decode(
+ self,
+ tgt: torch.Tensor,
+ memory: torch.Tensor,
+ tgt_mask: Optional[Tensor] = None,
+ tgt_padding_mask: Optional[Tensor] = None,
+ tgt_query: Optional[Tensor] = None,
+ tgt_query_mask: Optional[Tensor] = None,
+ pos_query: torch.Tensor = None,
+ ):
+ N, L = tgt.shape
+        # [BOS] stands for the null context. We only supply position information for characters after [BOS].
+ null_ctx = self.text_embed(tgt[:, :1])
+
+ if tgt_query is None:
+ tgt_query = pos_query[:, :L]
+ tgt_emb = pos_query[:, :L - 1] + self.text_embed(tgt[:, 1:])
+ tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
+
+ tgt_query = self.dropout(tgt_query)
+ return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask,
+ tgt_mask, tgt_padding_mask)
+
+ def forward(self, x, data=None, pos_query=None):
+ if self.training:
+ return self.training_step([x, pos_query, data[0]])
+ else:
+ return self.forward_test(x, pos_query)
+
+ def forward_test(self,
+ memory: Tensor,
+ pos_query: Tensor = None,
+ max_length: Optional[int] = None) -> Tensor:
+ _device = memory.get_device()
+ testing = max_length is None
+ max_length = (self.max_label_length if max_length is None else min(
+ max_length, self.max_label_length))
+ bs = memory.shape[0]
+        # +1 for [EOS] at the end of the sequence.
+ num_steps = max_length + 1
+ # memory = self.encode(images)
+
+ # Query positions up to `num_steps`
+ if pos_query is None:
+ pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)
+ else:
+ pos_queries = pos_query
+
+ # Special case for the forward permutation. Faster than using `generate_attn_masks()`
+ tgt_mask = query_mask = torch.triu(
+ torch.full((num_steps, num_steps), float('-inf'), device=_device),
+ 1)
+ self.attn_maps = []
+ if self.decode_ar:
+ tgt_in = torch.full((bs, num_steps),
+ self.pad_id,
+ dtype=torch.long,
+ device=_device)
+ tgt_in[:, 0] = self.bos_id
+
+ logits = []
+ for i in range(num_steps):
+ j = i + 1 # next token index
+ # Efficient decoding:
+ # Input the context up to the ith token. We use only one query (at position = i) at a time.
+ # This works because of the lookahead masking effect of the canonical (forward) AR context.
+ # Past tokens have no access to future tokens, hence are fixed once computed.
+ tgt_out = self.decode(
+ tgt_in[:, :j],
+ memory,
+ tgt_mask[:j, :j],
+ tgt_query=pos_queries[:, i:j],
+ tgt_query_mask=query_mask[i:j, :j],
+ pos_query=pos_queries,
+ )
+ self.attn_maps.append(self.decoder.layers[-1].attn_map)
+ # the next token probability is in the output's ith token position
+ p_i = self.head(tgt_out)
+ logits.append(p_i)
+ if j < num_steps:
+ # greedy decode. add the next token index to the target input
+ tgt_in[:, j] = p_i.squeeze().argmax(-1)
+ # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
+ if testing and (tgt_in == self.eos_id).any(dim=-1).all():
+ break
+
+ logits = torch.cat(logits, dim=1)
+ else:
+            # No prior context, so the input is just [BOS]. We query all positions.
+ tgt_in = torch.full((bs, 1),
+ self.bos_id,
+ dtype=torch.long,
+ device=_device)
+ tgt_out = self.decode(tgt_in,
+ memory,
+ tgt_query=pos_queries,
+ pos_query=pos_queries)
+ logits = self.head(tgt_out)
+
+ if self.refine_iters:
+ # For iterative refinement, we always use a 'cloze' mask.
+ # We can derive it from the AR forward mask by unmasking the token context to the right.
+ query_mask[torch.triu(
+ torch.ones(num_steps,
+ num_steps,
+ dtype=torch.bool,
+ device=_device), 2)] = 0
+ bos = torch.full((bs, 1),
+ self.bos_id,
+ dtype=torch.long,
+ device=_device)
+ for i in range(self.refine_iters):
+ # Prior context is the previous output.
+ tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
+ tgt_padding_mask = (tgt_in == self.eos_id).int().cumsum(
+ -1) > 0 # mask tokens beyond the first EOS token.
+ tgt_out = self.decode(
+ tgt_in,
+ memory,
+ tgt_mask,
+ tgt_padding_mask,
+ tgt_query=pos_queries,
+ tgt_query_mask=query_mask[:, :tgt_in.shape[1]],
+ pos_query=pos_queries,
+ )
+ logits = self.head(tgt_out)
+
+ return F.softmax(logits, -1)
+
+ def gen_tgt_perms(self, tgt, _device):
+ """Generate shared permutations for the whole batch.
+
+ This works because the same attention mask can be used for the shorter
+ sequences because of the padding mask.
+ """
+ # We don't permute the position of BOS, we permute EOS separately
+ max_num_chars = tgt.shape[1] - 2
+ # Special handling for 1-character sequences
+ if max_num_chars == 1:
+ return torch.arange(3, device=_device).unsqueeze(0)
+ perms = [torch.arange(max_num_chars, device=_device)
+ ] if self.perm_forward else []
+ # Additional permutations if needed
+ max_perms = math.factorial(max_num_chars)
+ if self.perm_mirrored:
+ max_perms //= 2
+ num_gen_perms = min(self.max_gen_perms, max_perms)
+ # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions
+ # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars.
+ if max_num_chars < 5:
+ # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
+ # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
+ if max_num_chars == 4 and self.perm_mirrored:
+ selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
+ else:
+ selector = list(range(max_perms))
+ perm_pool = torch.as_tensor(list(
+ permutations(range(max_num_chars), max_num_chars)),
+ device=_device)[selector]
+ # If the forward permutation is always selected, no need to add it to the pool for sampling
+ if self.perm_forward:
+ perm_pool = perm_pool[1:]
+ perms = torch.stack(perms)
+ if len(perm_pool):
+ i = self.rng.choice(len(perm_pool),
+ size=num_gen_perms - len(perms),
+ replace=False)
+ perms = torch.cat([perms, perm_pool[i]])
+ else:
+ perms.extend([
+ torch.randperm(max_num_chars, device=_device)
+ for _ in range(num_gen_perms - len(perms))
+ ])
+ perms = torch.stack(perms)
+ if self.perm_mirrored:
+ # Add complementary pairs
+ comp = perms.flip(-1)
+ # Stack in such a way that the pairs are next to each other.
+ perms = torch.stack([perms, comp
+ ]).transpose(0, 1).reshape(-1, max_num_chars)
+ # NOTE:
+ # The only meaningful way of permuting the EOS position is by moving it one character position at a time.
+ # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS
+ # positions will always be much less than the number of permutations (unless a low perm_num is set).
+ # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly
+ # distribute it across the chosen number of permutations.
+ # Add position indices of BOS and EOS
+ bos_idx = perms.new_zeros((len(perms), 1))
+ eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1)
+ perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1)
+ # Special handling for the reverse direction. This does two things:
+ # 1. Reverse context for the characters
+ # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode)
+ if len(perms) > 1:
+ perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1,
+ device=_device)
+ return perms
+
+ def generate_attn_masks(self, perm, _device):
+ """Generate attention masks given a sequence permutation (includes pos.
+ for bos and eos tokens)
+
+ :param perm: the permutation sequence. i = 0 is always the BOS
+ :return: lookahead attention masks
+ """
+ sz = perm.shape[0]
+ mask = torch.zeros((sz, sz), device=_device)
+ for i in range(sz):
+ query_idx = perm[i]
+ masked_keys = perm[i + 1:]
+ mask[query_idx, masked_keys] = float('-inf')
+ content_mask = mask[:-1, :-1].clone()
+ mask[torch.eye(sz, dtype=torch.bool,
+ device=_device)] = float('-inf') # mask "self"
+ query_mask = mask[1:, :-1]
+ return content_mask, query_mask
+
+ def training_step(self, batch):
+ memory, pos_query, tgt = batch
+ bs = memory.shape[0]
+ if pos_query is None:
+ pos_query = self.pos_queries.expand(bs, -1, -1)
+
+ # Prepare the target sequences (input and output)
+ tgt_perms = self.gen_tgt_perms(tgt, memory.get_device())
+ tgt_in = tgt[:, :-1]
+ tgt_out = tgt[:, 1:]
+ # The [EOS] token is not depended upon by any other token in any permutation ordering
+ tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
+
+ loss = 0
+ loss_numel = 0
+ n = (tgt_out != self.pad_id).sum().item()
+ for i, perm in enumerate(tgt_perms):
+ tgt_mask, query_mask = self.generate_attn_masks(
+ perm, memory.get_device())
+ out = self.decode(
+ tgt_in,
+ memory,
+ tgt_mask,
+ tgt_padding_mask,
+ tgt_query_mask=query_mask,
+ pos_query=pos_query,
+ )
+ logits = self.head(out)
+ if i == 0:
+ final_out = logits
+ loss += n * F.cross_entropy(logits.flatten(end_dim=1),
+ tgt_out.flatten(),
+ ignore_index=self.pad_id)
+ loss_numel += n
+ # After the second iteration (i.e. done with canonical and reverse orderings),
+ # remove the [EOS] tokens for the succeeding perms
+ if i == 1:
+ tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id,
+ tgt_out)
+ n = (tgt_out != self.pad_id).sum().item()
+ loss /= loss_numel
+
+ # self.log('loss', loss)
+ return [loss, final_out]
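+
+
+# A small CPU sketch of the permutation machinery above: generate the shared
+# target permutations for a toy batch and turn the first (canonical) one into
+# its content/query attention masks. Tensor sizes and id values are dummies.
+if __name__ == '__main__':
+    decoder = PARSeqDecoder(in_channels=384, out_channels=38)
+    device = torch.device('cpu')
+    # [BOS] + 5 characters + [EOS] -> tgt shape (N, 7); the id values are arbitrary
+    tgt = torch.randint(1, 36, (4, 7))
+    perms = decoder.gen_tgt_perms(tgt, device)
+    print(perms.shape)  # torch.Size([6, 7]) with default settings; row 0 is the forward order
+    content_mask, query_mask = decoder.generate_attn_masks(perms[0], device)
+    print(content_mask.shape, query_mask.shape)  # torch.Size([6, 6]) torch.Size([6, 6])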
diff --git a/openrec/modeling/decoders/rctc_decoder.py b/openrec/modeling/decoders/rctc_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bcdbb29fa410136ffc9e148fab15865112feca3
--- /dev/null
+++ b/openrec/modeling/decoders/rctc_decoder.py
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import trunc_normal_
+
+from openrec.modeling.common import Block
+
+
+class RCTCDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels=6625,
+ return_feats=False,
+ **kwargs):
+ super(RCTCDecoder, self).__init__()
+ self.char_token = nn.Parameter(
+ torch.zeros([1, 1, in_channels], dtype=torch.float32),
+ requires_grad=True,
+ )
+ trunc_normal_(self.char_token, mean=0, std=0.02)
+ self.fc = nn.Linear(
+ in_channels,
+ out_channels,
+ bias=True,
+ )
+ self.fc_kv = nn.Linear(
+ in_channels,
+ 2 * in_channels,
+ bias=True,
+ )
+ self.w_atten_block = Block(dim=in_channels,
+ num_heads=in_channels // 32,
+ mlp_ratio=4.0,
+ qkv_bias=False)
+ self.out_channels = out_channels
+ self.return_feats = return_feats
+
+ def forward(self, x, data=None):
+
+ B, C, H, W = x.shape
+ x = self.w_atten_block(x.permute(0, 2, 3,
+ 1).reshape(-1, W, C)).reshape(
+ B, H, W, C).permute(0, 3, 1, 2)
+ # B, D, 8, 32
+ x_kv = self.fc_kv(x.flatten(2).transpose(1, 2)).reshape(
+ B, H * W, 2, C).permute(2, 0, 3, 1) # 2, b, c, hw
+ x_k, x_v = x_kv.unbind(0) # b, c, hw
+ char_token = self.char_token.tile([B, 1, 1])
+ attn_ctc2d = char_token @ x_k # b, 1, hw
+ attn_ctc2d = attn_ctc2d.reshape([-1, 1, H, W])
+ attn_ctc2d = F.softmax(attn_ctc2d, 2) # b, 1, h, w
+ attn_ctc2d = attn_ctc2d.permute(0, 3, 1, 2) # b, w, 1, h
+ x_v = x_v.reshape(B, C, H, W)
+ # B, W, H, C
+ feats = attn_ctc2d @ x_v.permute(0, 3, 2, 1) # b, w, 1, c
+ feats = feats.squeeze(2) # b, w, c
+
+ predicts = self.fc(feats)
+
+ if self.return_feats:
+ result = (feats, predicts)
+ else:
+ result = predicts
+
+ if not self.training:
+ predicts = F.softmax(predicts, dim=2)
+ result = predicts
+
+ return result
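+
+
+# Minimal CPU sketch: the decoder above collapses the 2-D feature map into a
+# width-long sequence with a single learned character query, then applies a
+# CTC-style linear classifier. Sizes are illustrative; this assumes
+# openrec.modeling.common.Block behaves as it is used in forward().
+if __name__ == '__main__':
+    decoder = RCTCDecoder(in_channels=64, out_channels=97)
+    decoder.eval()
+    feats = torch.randn(2, 64, 8, 32)  # (B, C, H, W) backbone features
+    probs = decoder(feats)
+    print(probs.shape)  # torch.Size([2, 32, 97]) -> (B, W, out_channels)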
diff --git a/openrec/modeling/decoders/robustscanner_decoder.py b/openrec/modeling/decoders/robustscanner_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..584a768c324bb5b0d703342afd981bcfc3a40a9b
--- /dev/null
+++ b/openrec/modeling/decoders/robustscanner_decoder.py
@@ -0,0 +1,749 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class RobustScannerDecoder(nn.Module):
+
+ def __init__(
+ self,
+ out_channels, # 90 + unknown + start + padding
+ in_channels,
+ enc_outchannles=128,
+ hybrid_dec_rnn_layers=2,
+ hybrid_dec_dropout=0,
+ position_dec_rnn_layers=2,
+ max_len=25,
+ mask=True,
+ encode_value=False,
+ **kwargs):
+ super(RobustScannerDecoder, self).__init__()
+
+ start_idx = out_channels - 2
+ padding_idx = out_channels - 1
+ end_idx = 0
+ # encoder module
+ self.encoder = ChannelReductionEncoder(in_channels=in_channels,
+ out_channels=enc_outchannles)
+ self.max_text_length = max_len + 1
+ self.mask = mask
+ # decoder module
+ self.decoder = Decoder(
+ num_classes=out_channels,
+ dim_input=in_channels,
+ dim_model=enc_outchannles,
+ hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers,
+ hybrid_decoder_dropout=hybrid_dec_dropout,
+ position_decoder_rnn_layers=position_dec_rnn_layers,
+ max_len=max_len + 1,
+ start_idx=start_idx,
+ mask=mask,
+ padding_idx=padding_idx,
+ end_idx=end_idx,
+ encode_value=encode_value)
+
+ def forward(self, inputs, data=None):
+ '''
+        data: [label, length, ..., valid_ratio]
+ '''
+ out_enc = self.encoder(inputs)
+ bs = out_enc.shape[0]
+ valid_ratios = None
+ word_positions = torch.arange(0,
+ self.max_text_length,
+ device=inputs.device).unsqueeze(0).tile(
+ [bs, 1])
+
+ if self.mask:
+ valid_ratios = data[-1]
+
+ if self.training:
+ max_len = data[1].max()
+ label = data[0][:, :1 + max_len] # label
+ final_out = self.decoder(inputs, out_enc, label, valid_ratios,
+ word_positions[:, :1 + max_len])
+ if not self.training:
+ final_out = self.decoder(inputs,
+ out_enc,
+ label=None,
+ valid_ratios=valid_ratios,
+ word_positions=word_positions,
+ train_mode=False)
+ return final_out
+
+
+class BaseDecoder(nn.Module):
+
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ def forward_train(self, feat, out_enc, targets, img_metas):
+ raise NotImplementedError
+
+ def forward_test(self, feat, out_enc, img_metas):
+ raise NotImplementedError
+
+ def forward(self,
+ feat,
+ out_enc,
+ label=None,
+ valid_ratios=None,
+ word_positions=None,
+ train_mode=True):
+ self.train_mode = train_mode
+
+ if train_mode:
+ return self.forward_train(feat, out_enc, label, valid_ratios,
+ word_positions)
+ return self.forward_test(feat, out_enc, valid_ratios, word_positions)
+
+
+class ChannelReductionEncoder(nn.Module):
+ """Change the channel number with a one by one convoluational layer.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ """
+
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(ChannelReductionEncoder, self).__init__()
+
+ weight = torch.nn.Parameter(
+ torch.nn.init.xavier_normal_(torch.empty(out_channels, in_channels,
+ 1, 1),
+ gain=1.0))
+ self.layer = nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+        # initialize the 1x1 convolution with Xavier-normal weights
+        self.layer.weight = weight
+
+ def forward(self, feat):
+ """
+ Args:
+ feat (Tensor): Image features with the shape of
+ :math:`(N, C_{in}, H, W)`.
+
+ Returns:
+ Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`.
+ """
+ return self.layer(feat)
+
+
+def masked_fill(x, mask, value):
+    y = torch.full(x.shape, value, dtype=x.dtype, device=x.device)
+ return torch.where(mask, y, x)
+
+
+class DotProductAttentionLayer(nn.Module):
+
+ def __init__(self, dim_model=None):
+ super().__init__()
+
+ self.scale = dim_model**-0.5 if dim_model is not None else 1.
+
+ def forward(self, query, key, value, mask=None):
+
+ query = query.permute(0, 2, 1)
+ logits = query @ key * self.scale
+
+ if mask is not None:
+ n, seq_len = mask.size()
+ mask = mask.view(n, 1, seq_len)
+ logits = logits.masked_fill(mask, float('-inf'))
+
+ weights = F.softmax(logits, dim=2)
+ value = value.transpose(1, 2)
+ glimpse = weights @ value
+ glimpse = glimpse.permute(0, 2, 1).contiguous()
+ return glimpse
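+
+
+# Illustrative sketch (not called anywhere in this module) of how the scaled
+# dot-product attention above is driven by the decoders below: queries, keys
+# and values are laid out channel-first, and a boolean mask flags padded
+# spatial positions. All sizes here are arbitrary examples.
+def _dot_product_attention_example():
+    n, c, t, hw = 2, 128, 26, 256
+    layer = DotProductAttentionLayer(dim_model=c)
+    query = torch.randn(n, c, t)  # one query per decoding step
+    key = torch.randn(n, c, hw)  # encoder keys over H*W positions
+    value = torch.randn(n, c, hw)  # encoder values over H*W positions
+    mask = torch.zeros(n, hw, dtype=torch.bool)  # True marks padded positions
+    glimpse = layer(query, key, value, mask)  # (n, c, t)
+    return glimpse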
+
+
+class SequenceAttentionDecoder(BaseDecoder):
+ """Sequence attention decoder for RobustScanner.
+
+    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
+    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_
+
+ Args:
+ num_classes (int): Number of output classes :math:`C`.
+ rnn_layers (int): Number of RNN layers.
+ dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
+ dim_model (int): Dimension :math:`D_m` of the model. Should also be the
+ same as encoder output vector ``out_enc``.
+ max_seq_len (int): Maximum output sequence length :math:`T`.
+        start_idx (int): The index of the start token.
+ mask (bool): Whether to mask input features according to
+ ``img_meta['valid_ratio']``.
+        padding_idx (int): The index of the padding token.
+ dropout (float): Dropout rate.
+ return_feature (bool): Return feature or logits as the result.
+ encode_value (bool): Whether to use the output of encoder ``out_enc``
+ as `value` of attention layer. If False, the original feature
+ ``feat`` will be used.
+
+ Warning:
+        This decoder will not predict the final class, which is assumed to be
+        the padding token. Therefore, its output size is always :math:`C - 1`,
+        and the padding token is also ignored by the loss.
+ """
+
+ def __init__(self,
+ num_classes=None,
+ rnn_layers=2,
+ dim_input=512,
+ dim_model=128,
+ max_seq_len=40,
+ start_idx=0,
+ mask=True,
+ padding_idx=None,
+ dropout=0,
+ return_feature=False,
+ encode_value=False):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.dim_input = dim_input
+ self.dim_model = dim_model
+ self.return_feature = return_feature
+ self.encode_value = encode_value
+ self.max_seq_len = max_seq_len
+ self.start_idx = start_idx
+ self.mask = mask
+
+ self.embedding = nn.Embedding(self.num_classes,
+ self.dim_model,
+ padding_idx=padding_idx)
+
+ self.sequence_layer = nn.LSTM(input_size=dim_model,
+ hidden_size=dim_model,
+ num_layers=rnn_layers,
+ batch_first=True,
+ dropout=dropout)
+
+ self.attention_layer = DotProductAttentionLayer()
+
+ self.prediction = None
+ if not self.return_feature:
+ pred_num_classes = num_classes - 1
+ self.prediction = nn.Linear(
+ dim_model if encode_value else dim_input, pred_num_classes)
+
+ def forward_train(self, feat, out_enc, targets, valid_ratios):
+ """
+ Args:
+ feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
+ out_enc (Tensor): Encoder output of shape
+ :math:`(N, D_m, H, W)`.
+ targets (Tensor): a tensor of shape :math:`(N, T)`. Each element is the index of a
+ character.
+ valid_ratios (Tensor): valid length ratio of img.
+ Returns:
+ Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
+ ``return_feature=False``. Otherwise it would be the hidden feature
+ before the prediction projection layer, whose shape is
+ :math:`(N, T, D_m)`.
+ """
+
+ tgt_embedding = self.embedding(targets)
+
+ n, c_enc, h, w = out_enc.shape
+ assert c_enc == self.dim_model
+ _, c_feat, _, _ = feat.shape
+ assert c_feat == self.dim_input
+ _, len_q, c_q = tgt_embedding.shape
+ assert c_q == self.dim_model
+ assert len_q <= self.max_seq_len
+
+ query, _ = self.sequence_layer(tgt_embedding)
+
+ query = query.permute(0, 2, 1).contiguous()
+
+ key = out_enc.view(n, c_enc, h * w)
+
+ if self.encode_value:
+ value = key
+ else:
+ value = feat.view(n, c_feat, h * w)
+
+ mask = None
+ if valid_ratios is not None:
+ mask = query.new_zeros((n, h, w))
+ for i, valid_ratio in enumerate(valid_ratios):
+ valid_width = min(w, math.ceil(w * valid_ratio))
+ mask[i, :, valid_width:] = 1
+ mask = mask.bool()
+ mask = mask.view(n, h * w)
+
+ attn_out = self.attention_layer(query, key, value, mask)
+ attn_out = attn_out.permute(0, 2, 1).contiguous()
+
+ if self.return_feature:
+ return attn_out
+
+ out = self.prediction(attn_out)
+
+ return out
+
+ def forward_test(self, feat, out_enc, valid_ratios):
+ """
+ Args:
+ feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
+ out_enc (Tensor): Encoder output of shape
+ :math:`(N, D_m, H, W)`.
+ valid_ratios (Tensor): valid length ratio of img.
+
+ Returns:
+ Tensor: The output logit sequence tensor of shape
+ :math:`(N, T, C-1)`.
+ """
+ batch_size = feat.shape[0]
+
+ decode_sequence = (torch.ones((batch_size, self.max_seq_len),
+ dtype=torch.int64,
+ device=feat.device) * self.start_idx)
+
+ outputs = []
+ for i in range(self.max_seq_len):
+ step_out = self.forward_test_step(feat, out_enc, decode_sequence,
+ i, valid_ratios)
+ outputs.append(step_out)
+ max_idx = torch.argmax(step_out, dim=1, keepdim=False)
+ if i < self.max_seq_len - 1:
+ decode_sequence[:, i + 1] = max_idx
+
+ outputs = torch.stack(outputs, 1)
+
+ return outputs
+
+ def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
+ valid_ratios):
+ """
+ Args:
+ feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
+ out_enc (Tensor): Encoder output of shape
+ :math:`(N, D_m, H, W)`.
+ decode_sequence (Tensor): Shape :math:`(N, T)`. The tensor that
+ stores history decoding result.
+ current_step (int): Current decoding step.
+ valid_ratios (Tensor): valid length ratio of img
+
+ Returns:
+ Tensor: Shape :math:`(N, C-1)`. The logit tensor of predicted
+ tokens at current time step.
+ """
+
+ embed = self.embedding(decode_sequence)
+
+ n, c_enc, h, w = out_enc.shape
+ assert c_enc == self.dim_model
+ _, c_feat, _, _ = feat.shape
+ assert c_feat == self.dim_input
+ _, _, c_q = embed.shape
+ assert c_q == self.dim_model
+
+ query, _ = self.sequence_layer(embed)
+ query = query.transpose(1, 2)
+ key = torch.reshape(out_enc, (n, c_enc, h * w))
+ if self.encode_value:
+ value = key
+ else:
+ value = torch.reshape(feat, (n, c_feat, h * w))
+
+ mask = None
+ if valid_ratios is not None:
+ mask = query.new_zeros((n, h, w))
+ for i, valid_ratio in enumerate(valid_ratios):
+ valid_width = min(w, math.ceil(w * valid_ratio))
+ mask[i, :, valid_width:] = 1
+ mask = mask.bool()
+ mask = mask.view(n, h * w)
+
+ # [n, c, l]
+ attn_out = self.attention_layer(query, key, value, mask)
+ out = attn_out[:, :, current_step]
+
+ if self.return_feature:
+ return out
+
+ out = self.prediction(out)
+ out = F.softmax(out, dim=-1)
+
+ return out
+
+
+class PositionAwareLayer(nn.Module):
+
+ def __init__(self, dim_model, rnn_layers=2):
+ super().__init__()
+
+ self.dim_model = dim_model
+
+ self.rnn = nn.LSTM(input_size=dim_model,
+ hidden_size=dim_model,
+ num_layers=rnn_layers,
+ batch_first=True)
+
+ self.mixer = nn.Sequential(
+ nn.Conv2d(dim_model, dim_model, kernel_size=3, stride=1,
+ padding=1), nn.ReLU(True),
+ nn.Conv2d(dim_model, dim_model, kernel_size=3, stride=1,
+ padding=1))
+
+ def forward(self, img_feature):
+ n, c, h, w = img_feature.shape
+ rnn_input = img_feature.permute(0, 2, 3, 1).contiguous()
+ rnn_input = rnn_input.view(n * h, w, c)
+ rnn_output, _ = self.rnn(rnn_input)
+ rnn_output = rnn_output.view(n, h, w, c)
+ rnn_output = rnn_output.permute(0, 3, 1, 2).contiguous()
+
+ out = self.mixer(rnn_output)
+ return out
+
+
+class PositionAttentionDecoder(BaseDecoder):
+ """Position attention decoder for RobustScanner.
+
+    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
+    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_
+
+ Args:
+ num_classes (int): Number of output classes :math:`C`.
+ rnn_layers (int): Number of RNN layers.
+ dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
+ dim_model (int): Dimension :math:`D_m` of the model. Should also be the
+ same as encoder output vector ``out_enc``.
+ max_seq_len (int): Maximum output sequence length :math:`T`.
+ mask (bool): Whether to mask input features according to
+ ``img_meta['valid_ratio']``.
+ return_feature (bool): Return feature or logits as the result.
+ encode_value (bool): Whether to use the output of encoder ``out_enc``
+ as `value` of attention layer. If False, the original feature
+ ``feat`` will be used.
+
+    Warning:
+        This decoder will not predict the final class which is assumed to be
+        ``<PAD>``. Therefore, its output size is always :math:`C - 1`.
+        ``<PAD>`` is also ignored by the loss.
+    """
+
+ def __init__(self,
+ num_classes=None,
+ rnn_layers=2,
+ dim_input=512,
+ dim_model=128,
+ max_seq_len=40,
+ mask=True,
+ return_feature=False,
+ encode_value=False):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.dim_input = dim_input
+ self.dim_model = dim_model
+ self.max_seq_len = max_seq_len
+ self.return_feature = return_feature
+ self.encode_value = encode_value
+ self.mask = mask
+
+ self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model)
+
+ self.position_aware_module = PositionAwareLayer(
+ self.dim_model, rnn_layers)
+
+ self.attention_layer = DotProductAttentionLayer()
+
+ self.prediction = None
+ if not self.return_feature:
+ pred_num_classes = num_classes - 1
+ self.prediction = nn.Linear(
+ dim_model if encode_value else dim_input, pred_num_classes)
+
+ def _get_position_index(self, length, batch_size):
+ position_index_list = []
+ for i in range(batch_size):
+            position_index = torch.arange(0, length, dtype=torch.int64)
+ position_index_list.append(position_index)
+ batch_position_index = torch.stack(position_index_list, dim=0)
+ return batch_position_index
+
+ def forward_train(self, feat, out_enc, targets, valid_ratios,
+ position_index):
+ """
+ Args:
+ feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
+ out_enc (Tensor): Encoder output of shape
+ :math:`(N, D_m, H, W)`.
+ targets (dict): A dict with the key ``padded_targets``, a
+ tensor of shape :math:`(N, T)`. Each element is the index of a
+ character.
+ valid_ratios (Tensor): valid length ratio of img.
+ position_index (Tensor): The position of each word.
+
+ Returns:
+ Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
+ ``return_feature=False``. Otherwise it will be the hidden feature
+ before the prediction projection layer, whose shape is
+ :math:`(N, T, D_m)`.
+ """
+ n, c_enc, h, w = out_enc.shape
+ assert c_enc == self.dim_model
+ _, c_feat, _, _ = feat.shape
+ assert c_feat == self.dim_input
+ _, len_q = targets.shape
+ assert len_q <= self.max_seq_len
+
+ position_out_enc = self.position_aware_module(out_enc)
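+        # Queries are learned position embeddings only, independent of the
+        # decoded characters, so this branch provides purely positional clues.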
+
+ query = self.embedding(position_index)
+ query = query.permute(0, 2, 1).contiguous()
+ key = position_out_enc.view(n, c_enc, h * w)
+ if self.encode_value:
+ value = out_enc.view(n, c_enc, h * w)
+ else:
+ value = feat.view(n, c_feat, h * w)
+
+ mask = None
+ if valid_ratios is not None:
+ mask = query.new_zeros((n, h, w))
+ for i, valid_ratio in enumerate(valid_ratios):
+ valid_width = min(w, math.ceil(w * valid_ratio))
+ mask[i, :, valid_width:] = 1
+ mask = mask.bool()
+ mask = mask.view(n, h * w)
+
+ attn_out = self.attention_layer(query, key, value, mask)
+ attn_out = attn_out.permute(0, 2, 1).contiguous()
+
+ if self.return_feature:
+ return attn_out
+
+ return self.prediction(attn_out)
+
+ def forward_test(self, feat, out_enc, valid_ratios, position_index):
+ """
+ Args:
+ feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
+ out_enc (Tensor): Encoder output of shape
+ :math:`(N, D_m, H, W)`.
+ valid_ratios (Tensor): valid length ratio of img
+ position_index (Tensor): The position of each word.
+
+ Returns:
+ Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
+ ``return_feature=False``. Otherwise it would be the hidden feature
+ before the prediction projection layer, whose shape is
+ :math:`(N, T, D_m)`.
+ """
+ n, c_enc, h, w = out_enc.shape
+ assert c_enc == self.dim_model
+ _, c_feat, _, _ = feat.shape
+ assert c_feat == self.dim_input
+
+ position_out_enc = self.position_aware_module(out_enc)
+
+ query = self.embedding(position_index)
+ query = query.permute(0, 2, 1).contiguous()
+ key = position_out_enc.view(n, c_enc, h * w)
+ if self.encode_value:
+ value = torch.reshape(out_enc, (n, c_enc, h * w))
+ else:
+ value = torch.reshape(feat, (n, c_feat, h * w))
+
+ mask = None
+ if valid_ratios is not None:
+ mask = query.new_zeros((n, h, w))
+ for i, valid_ratio in enumerate(valid_ratios):
+ valid_width = min(w, math.ceil(w * valid_ratio))
+ mask[i, :, valid_width:] = 1
+ mask = mask.bool()
+ mask = mask.view(n, h * w)
+
+ attn_out = self.attention_layer(query, key, value, mask)
+ attn_out = attn_out.transpose(1, 2) # [n, len_q, dim_v]
+
+ if self.return_feature:
+ return attn_out
+
+ return self.prediction(attn_out)
+
+
+class RobustScannerFusionLayer(nn.Module):
+
+ def __init__(self, dim_model, dim=-1):
+ super(RobustScannerFusionLayer, self).__init__()
+
+ self.dim_model = dim_model
+ self.dim = dim
+ self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2)
+
+ def forward(self, x0, x1):
+ assert x0.shape == x1.shape
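+        # Gated fusion: concatenate the two glimpses, project them, and let a
+        # GLU decide how much of each branch to keep.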
+ fusion_input = torch.concat((x0, x1), self.dim)
+ output = self.linear_layer(fusion_input)
+ output = F.glu(output, self.dim)
+
+ return output
+
+
+class Decoder(BaseDecoder):
+ """Decoder for RobustScanner.
+
+ RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
+    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_
+
+ Args:
+ num_classes (int): Number of output classes :math:`C`.
+ dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
+ dim_model (int): Dimension :math:`D_m` of the model. Should also be the
+ same as encoder output vector ``out_enc``.
+ max_seq_len (int): Maximum output sequence length :math:`T`.
+        start_idx (int): The index of ``<BOS>``.
+ mask (bool): Whether to mask input features according to
+ ``img_meta['valid_ratio']``.
+        padding_idx (int): The index of ``<PAD>``.
+ encode_value (bool): Whether to use the output of encoder ``out_enc``
+ as `value` of attention layer. If False, the original feature
+ ``feat`` will be used.
+
+    Warning:
+        This decoder will not predict the final class which is assumed to be
+        ``<PAD>``. Therefore, its output size is always :math:`C - 1`.
+        ``<PAD>`` is also ignored by the loss as specified in
+        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
+ """
+
+ def __init__(self,
+ num_classes=None,
+ dim_input=512,
+ dim_model=128,
+ hybrid_decoder_rnn_layers=2,
+ hybrid_decoder_dropout=0,
+ position_decoder_rnn_layers=2,
+ max_len=40,
+ start_idx=0,
+ mask=True,
+ padding_idx=None,
+ end_idx=0,
+ encode_value=False):
+ super().__init__()
+ self.num_classes = num_classes
+ self.dim_input = dim_input
+ self.dim_model = dim_model
+ self.max_seq_len = max_len
+ self.encode_value = encode_value
+ self.start_idx = start_idx
+ self.padding_idx = padding_idx
+ self.end_idx = end_idx
+ self.mask = mask
+
+ # init hybrid decoder
+ self.hybrid_decoder = SequenceAttentionDecoder(
+ num_classes=num_classes,
+ rnn_layers=hybrid_decoder_rnn_layers,
+ dim_input=dim_input,
+ dim_model=dim_model,
+ max_seq_len=max_len,
+ start_idx=start_idx,
+ mask=mask,
+ padding_idx=padding_idx,
+ dropout=hybrid_decoder_dropout,
+ encode_value=encode_value,
+ return_feature=True)
+
+ # init position decoder
+ self.position_decoder = PositionAttentionDecoder(
+ num_classes=num_classes,
+ rnn_layers=position_decoder_rnn_layers,
+ dim_input=dim_input,
+ dim_model=dim_model,
+ max_seq_len=max_len,
+ mask=mask,
+ encode_value=encode_value,
+ return_feature=True)
+
+ self.fusion_module = RobustScannerFusionLayer(
+ self.dim_model if encode_value else dim_input)
+
+ pred_num_classes = num_classes
+ self.prediction = nn.Linear(dim_model if encode_value else dim_input,
+ pred_num_classes)
+
+ def forward_train(self, feat, out_enc, target, valid_ratios,
+ word_positions):
+ """
+ Args:
+ feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
+ out_enc (Tensor): Encoder output of shape
+ :math:`(N, D_m, H, W)`.
+ target (dict): A dict with the key ``padded_targets``, a
+ tensor of shape :math:`(N, T)`. Each element is the index of a
+ character.
+            valid_ratios (Tensor): valid length ratio of img.
+ word_positions (Tensor): The position of each word.
+
+ Returns:
+ Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`.
+ """
+
+ hybrid_glimpse = self.hybrid_decoder.forward_train(
+ feat, out_enc, target, valid_ratios)
+ position_glimpse = self.position_decoder.forward_train(
+ feat, out_enc, target, valid_ratios, word_positions)
+
+ fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse)
+
+ out = self.prediction(fusion_out)
+
+ return out
+
+ def forward_test(self, feat, out_enc, valid_ratios, word_positions):
+ """
+ Args:
+ feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
+ out_enc (Tensor): Encoder output of shape
+ :math:`(N, D_m, H, W)`.
+            valid_ratios (Tensor): valid length ratio of img.
+ word_positions (Tensor): The position of each word.
+ Returns:
+ Tensor: The output logit sequence tensor of shape
+ :math:`(N, T, C-1)`.
+ """
+ seq_len = self.max_seq_len
+ batch_size = feat.shape[0]
+
+ decode_sequence = (torch.ones(
+ (batch_size, seq_len), dtype=torch.int64, device=feat.device) *
+ self.start_idx)
+
+ position_glimpse = self.position_decoder.forward_test(
+ feat, out_enc, valid_ratios, word_positions)
+
+ outputs = []
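+        # At each step, fuse the context-aware (hybrid) glimpse with the
+        # position-aware glimpse before predicting the next character.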
+ for i in range(seq_len):
+ hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
+ feat, out_enc, decode_sequence, i, valid_ratios)
+
+ fusion_out = self.fusion_module(hybrid_glimpse_step,
+ position_glimpse[:, i, :])
+
+ char_out = self.prediction(fusion_out)
+ char_out = F.softmax(char_out, -1)
+ outputs.append(char_out)
+ max_idx = torch.argmax(char_out, dim=1, keepdim=False)
+ if i < seq_len - 1:
+ decode_sequence[:, i + 1] = max_idx
+ if (decode_sequence == self.end_idx).any(dim=-1).all():
+ break
+ outputs = torch.stack(outputs, 1)
+
+ return outputs
diff --git a/openrec/modeling/decoders/sar_decoder.py b/openrec/modeling/decoders/sar_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a439817108ff734c0fc1ff44c41a8152df1fc4b5
--- /dev/null
+++ b/openrec/modeling/decoders/sar_decoder.py
@@ -0,0 +1,236 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SAREncoder(nn.Module):
+
+ def __init__(self,
+ enc_bi_rnn=False,
+ enc_drop_rnn=0.1,
+ in_channels=512,
+ d_enc=512,
+ **kwargs):
+ super().__init__()
+
+ # LSTM Encoder
+ if enc_bi_rnn:
+ bidirectional = True
+ else:
+ bidirectional = False
+
+ hidden_size = d_enc
+
+ self.rnn_encoder = nn.LSTM(input_size=in_channels,
+ hidden_size=hidden_size,
+ num_layers=2,
+ dropout=enc_drop_rnn,
+ bidirectional=bidirectional,
+ batch_first=True)
+
+ # global feature transformation
+ encoder_rnn_out_size = hidden_size * (int(enc_bi_rnn) + 1)
+ self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
+
+ def forward(self, feat):
+
+ h_feat = feat.shape[2]
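+        # Max-pool over the height so the 2D feature map becomes a width-wise
+        # sequence for the LSTM encoder.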
+ feat_v = F.max_pool2d(feat,
+ kernel_size=(h_feat, 1),
+ stride=1,
+ padding=0)
+ feat_v = feat_v.squeeze(2)
+ feat_v = feat_v.permute(0, 2, 1).contiguous() # bsz * W * C
+
+ holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * hidden_size
+
+ valid_hf = holistic_feat[:, -1, :] # bsz * hidden_size
+
+ holistic_feat = self.linear(valid_hf) # bsz * C
+
+ return holistic_feat
+
+
+class SARDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ max_len=25,
+ enc_bi_rnn=False,
+ enc_drop_rnn=0.1,
+ dec_bi_rnn=False,
+ dec_drop_rnn=0.0,
+ pred_dropout=0.1,
+ pred_concat=True,
+ mask=True,
+ use_lstm=True,
+ **kwargs):
+ super(SARDecoder, self).__init__()
+
+ self.num_classes = out_channels
+ self.start_idx = out_channels - 2
+ self.padding_idx = out_channels - 1
+ self.end_idx = 0
+ self.max_seq_len = max_len + 1
+ self.pred_concat = pred_concat
+ self.mask = mask
+ enc_dim = in_channels
+ d = in_channels
+ embedding_dim = in_channels
+ dec_dim = in_channels
+ self.use_lstm = use_lstm
+ if use_lstm:
+ # encoder module
+ self.encoder = SAREncoder(enc_bi_rnn=enc_bi_rnn,
+ enc_drop_rnn=enc_drop_rnn,
+ in_channels=in_channels,
+ d_enc=enc_dim)
+
+ # decoder module
+
+ # 2D attention layer
+ self.conv1x1_1 = nn.Linear(dec_dim, d)
+ self.conv3x3_1 = nn.Conv2d(in_channels,
+ d,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.conv1x1_2 = nn.Linear(d, 1)
+
+ # Decoder input embedding
+ self.embedding = nn.Embedding(self.num_classes,
+ embedding_dim,
+ padding_idx=self.padding_idx)
+
+ self.rnndecoder = nn.LSTM(input_size=embedding_dim,
+ hidden_size=dec_dim,
+ num_layers=2,
+ dropout=dec_drop_rnn,
+ bidirectional=dec_bi_rnn,
+ batch_first=True)
+
+ # Prediction layer
+ self.pred_dropout = nn.Dropout(pred_dropout)
+ if pred_concat:
+ fc_in_channel = in_channels + in_channels + dec_dim
+ else:
+ fc_in_channel = in_channels
+ self.prediction = nn.Linear(fc_in_channel, self.num_classes)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def _2d_attation(self, feat, tokens, data, training):
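+        # 2D additive attention: queries come from the decoder LSTM hidden
+        # states, keys from a 3x3 conv over the feature map; tanh followed by
+        # a 1x1 projection yields one weight per spatial location and step.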
+
+ Hidden_state = self.rnndecoder(tokens)[0]
+ attn_query = self.conv1x1_1(Hidden_state)
+ bsz, seq_len, _ = attn_query.size()
+ attn_query = attn_query.unsqueeze(-1).unsqueeze(-1)
+ # bsz * seq_len+1 * attn_size * 1 * 1
+ attn_key = self.conv3x3_1(feat).unsqueeze(1)
+ # bsz * 1 * attn_size * h * w
+
+ attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
+ attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous()
+ attn_weight = self.conv1x1_2(attn_weight)
+
+ _, T, h, w, c = attn_weight.size()
+
+ if self.mask:
+ valid_ratios = data[-1]
+ # cal mask of attention weight
+ attn_mask = torch.zeros_like(attn_weight)
+ for i, valid_ratio in enumerate(valid_ratios):
+ valid_width = min(w, math.ceil(w * valid_ratio))
+ attn_mask[i, :, :, valid_width:, :] = 1
+ attn_weight = attn_weight.masked_fill(attn_mask.bool(),
+ float('-inf'))
+
+ attn_weight = attn_weight.view(bsz, T, -1)
+ attn_weight = F.softmax(attn_weight, dim=-1)
+ attn_weight = attn_weight.view(bsz, T, h, w,
+ c).permute(0, 1, 4, 2, 3).contiguous()
+ # bsz, T, 1, h, w
+ # bsz, 1, f_c ,h, w
+ attn_feat = torch.sum(torch.mul(feat.unsqueeze(1), attn_weight),
+ (3, 4),
+ keepdim=False)
+ return [Hidden_state, attn_feat]
+
+ def forward_train(self, feat, holistic_feat, data):
+
+ max_len = data[1].max()
+ label = data[0][:, :1 + max_len] # label
+ label_embedding = self.embedding(label)
+ holistic_feat = holistic_feat.unsqueeze(1)
+ tokens = torch.cat((holistic_feat, label_embedding), dim=1)
+
+ Hidden_state, attn_feat = self._2d_attation(feat,
+ tokens,
+ data,
+ training=self.training)
+
+ bsz, seq_len, f_c = Hidden_state.size()
+ # linear transformation
+ if self.pred_concat:
+ f_c = holistic_feat.size(-1)
+ holistic_feat = holistic_feat.expand(bsz, seq_len, f_c)
+ preds = self.prediction(
+ torch.cat((Hidden_state, attn_feat, holistic_feat), 2))
+ else:
+ preds = self.prediction(attn_feat)
+ # bsz * (seq_len + 1) * num_classes
+ preds = self.pred_dropout(preds)
+ return preds[:, 1:, :]
+
+ def forward_test(self, feat, holistic_feat, data=None):
+ bsz = feat.shape[0]
+ seq_len = self.max_seq_len
+ holistic_feat = holistic_feat.unsqueeze(1)
+ tokens = torch.full((bsz, ),
+ self.start_idx,
+ device=feat.device,
+ dtype=torch.long)
+ outputs = []
+ tokens = self.embedding(tokens)
+ tokens = tokens.unsqueeze(1).expand(-1, seq_len, -1)
+ tokens = torch.cat((holistic_feat, tokens), dim=1)
+ for i in range(1, seq_len + 1):
+ Hidden_state, attn_feat = self._2d_attation(feat,
+ tokens,
+ data=data,
+ training=self.training)
+ if self.pred_concat:
+ f_c = holistic_feat.size(-1)
+ holistic_feat = holistic_feat.expand(bsz, seq_len + 1, f_c)
+ preds = self.prediction(
+ torch.cat((Hidden_state, attn_feat, holistic_feat), 2))
+ else:
+ preds = self.prediction(attn_feat)
+ # bsz * (seq_len + 1) * num_classes
+ char_output = preds[:, i, :]
+ char_output = F.softmax(char_output, -1)
+ outputs.append(char_output)
+ _, max_idx = torch.max(char_output, dim=1, keepdim=False)
+ char_embedding = self.embedding(max_idx)
+ if (i < seq_len):
+ tokens[:, i + 1, :] = char_embedding
+ if (tokens == self.end_idx).any(dim=-1).all():
+ break
+ outputs = torch.stack(outputs, 1)
+
+ return outputs
+
+ def forward(self, feat, data=None):
+ if self.use_lstm:
+ holistic_feat = self.encoder(feat) # bsz c
+ else:
+ holistic_feat = F.adaptive_avg_pool2d(feat, (1, 1)).squeeze()
+
+ if self.training:
+ preds = self.forward_train(feat, holistic_feat, data=data)
+ else:
+ preds = self.forward_test(feat, holistic_feat, data=data)
+ # (bsz, seq_len, num_classes)
+ return preds
diff --git a/openrec/modeling/decoders/smtr_decoder.py b/openrec/modeling/decoders/smtr_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..01bd7eedd88f809f519087f66ede8b82ba5cad32
--- /dev/null
+++ b/openrec/modeling/decoders/smtr_decoder.py
@@ -0,0 +1,621 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.init import ones_, trunc_normal_, zeros_
+
+from openrec.modeling.common import DropPath, Identity
+from openrec.modeling.decoders.nrtr_decoder import Embeddings
+
+
+class CrossAttention(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, q, kv, key_mask=None):
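+        # Multi-head cross-attention; key_mask is additive (0 / -inf) and is
+        # broadcast over the attention heads via unsqueeze(1).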
+ N, C = kv.shape[1:]
+ QN = q.shape[1]
+ q = self.q(q).reshape([-1, QN, self.num_heads,
+ C // self.num_heads]).transpose(1, 2)
+ q = q * self.scale
+ k, v = self.kv(kv).reshape(
+ [-1, N, 2, self.num_heads,
+ C // self.num_heads]).permute(2, 0, 3, 1, 4)
+
+ attn = q.matmul(k.transpose(2, 3))
+
+ if key_mask is not None:
+ attn = attn + key_mask.unsqueeze(1)
+
+ attn = F.softmax(attn, -1)
+ if not self.training:
+ self.attn_map = attn
+ attn = self.attn_drop(attn)
+
+ x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SSMatchLayer(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ nextq2subs_head2=None,
+ dynq2img_heads=2,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ epsilon=1e-6,
+ ):
+ super().__init__()
+ self.dim = dim
+ if nextq2subs_head2 is None:
+ nextq2subs_head2 = dim // 32
+ self.normq1 = nn.LayerNorm(dim, eps=epsilon)
+ self.normkv1 = nn.LayerNorm(dim, eps=epsilon)
+ self.images_to_question_cross_attn = CrossAttention(
+ dim,
+ num_heads=nextq2subs_head2,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop)
+ self.normq2 = nn.LayerNorm(dim, eps=epsilon)
+ self.normkv2 = nn.LayerNorm(dim, eps=epsilon)
+ self.question_to_images_cross_attn = CrossAttention(
+ dim,
+ num_heads=dynq2img_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+
+ def forward(self, question_f, prompt_f, visual_f, mask=None):
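+        # The query token first attends to the substring prompt, then the
+        # updated query attends to the visual features.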
+
+ question_f = question_f + self.drop_path(
+ self.images_to_question_cross_attn(self.normq1(question_f),
+ self.normkv1(prompt_f), mask))
+ question_f = question_f.reshape(visual_f.shape[0], -1, self.dim)
+ question_f = self.question_to_images_cross_attn(
+ self.normq2(question_f), self.normkv2(visual_f))
+
+ return question_f
+
+
+class SMTRDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_layer=2,
+ nextq2subs_head2=None,
+ dynq2img_heads=2,
+ drop_path_rate=0.1,
+ max_len=25,
+ vis_seq=50,
+ ds=False,
+ pos2d=False,
+ max_size=[8, 32],
+ sub_str_len=5,
+ next_mode=True,
+ infer_aug=False,
+ bi_attn=False,
+ **kwargs):
+ super(SMTRDecoder, self).__init__()
+
+ self.out_channels = out_channels
+ dim = in_channels
+ self.dim = dim
+ self.max_len = max_len + 3 # max_len + eos + bos
+ self.char_embed = Embeddings(d_model=dim,
+ vocab=self.out_channels,
+ scale_embedding=True)
+ self.ignore_index = out_channels - 1
+ self.sub_str_len = sub_str_len
+ self.bos_next = out_channels - 3
+ self.bos_pre = out_channels - 2
+ self.eos = 0
+ dpr = np.linspace(0, drop_path_rate, num_layer + 2)
+ self.next_mode = next_mode
+ self.infer_aug = infer_aug
+ self.bi_attn = bi_attn
+ self.cmff_decoder = nn.ModuleList([
+ SSMatchLayer(dim=dim,
+ nextq2subs_head2=nextq2subs_head2,
+ dynq2img_heads=dynq2img_heads,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=dpr[i]) for i in range(num_layer)
+ ])
+
+ self.ds = ds
+ self.pos2d = pos2d
+ if not ds:
+ self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+ trunc_normal_(self.vis_pos_embed, std=0.02)
+ elif pos2d:
+ pos_embed = torch.zeros([1, max_size[0] * max_size[1], dim],
+ dtype=torch.float32)
+ trunc_normal_(pos_embed, mean=0, std=0.02)
+ self.vis_pos_embed = nn.Parameter(pos_embed.transpose(
+ 1, 2).reshape(1, dim, max_size[0], max_size[1]),
+ requires_grad=True)
+
+ self.next_token = nn.Parameter(torch.zeros([1, 1, 1, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+
+ self.pre_token = nn.Parameter(torch.zeros([1, 1, 1, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+
+ self.prompt_next_embed = nn.Parameter(torch.zeros(
+ [1, 1, self.sub_str_len + 1, dim], dtype=torch.float32),
+ requires_grad=True)
+
+ self.prompt_pre_embed = nn.Parameter(torch.zeros(
+ [1, 1, self.sub_str_len + 1, dim], dtype=torch.float32),
+ requires_grad=True)
+
+ self.norm_pred = nn.LayerNorm(dim, eps=1e-6)
+ self.ques1_head = nn.Linear(dim, self.out_channels - 3)
+
+ trunc_normal_(self.next_token, std=0.02)
+ trunc_normal_(self.pre_token, std=0.02)
+ trunc_normal_(self.prompt_pre_embed, std=0.02)
+ trunc_normal_(self.prompt_next_embed, std=0.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'vis_pos_embed', 'pre_token', 'next_token', 'char_embed'}
+
+ def forward(self, x, data=None):
+ if self.training:
+ return self.forward_train(x, data)
+ else:
+ if self.infer_aug:
+ if self.bi_attn:
+ return self.forward_test_bi_attn(x)
+ return self.forward_test_bi(x)
+ return self.forward_test(x)
+
+ def forward_test_bi(self, x):
+ # self.attn_maps = []
+ if not self.ds:
+ visual_f = x + self.vis_pos_embed
+ elif self.pos2d:
+ visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
+            visual_f = visual_f.flatten(2).transpose(1, 2)
+ else:
+ visual_f = x
+ bs = 2
+ if 1:
+ next = self.next_token
+ pre = self.pre_token
+ next_pre = torch.concat([next, pre], 0)
+ next_pre = next_pre.squeeze(1) #2, 1, dim
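+            # Bidirectional long-text inference: decode with the next- and
+            # pre-token queries as a batch of two (one from each end of the
+            # string); if both runs are long enough, a final forward pass
+            # below stitches the middle segment together.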
+
+ prompt_next_embed = self.prompt_next_embed.squeeze(1)
+ prompt_pre_embed = self.prompt_pre_embed.squeeze(1)
+
+ next_id = torch.full([1, self.sub_str_len],
+ self.bos_next,
+ dtype=torch.long,
+ device=x.get_device())
+ pre_id = torch.full([1, self.sub_str_len],
+ self.bos_pre,
+ dtype=torch.long,
+ device=x.get_device())
+ # prompt_next_bos = self.char_embed(prompt_id)
+ # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.get_device())
+ next_pred_id_list = torch.full([1, self.max_len],
+ self.ignore_index,
+ dtype=torch.long,
+ device=x.get_device())
+ pre_pred_id_list = torch.full([1, self.max_len],
+ self.ignore_index,
+ dtype=torch.long,
+ device=x.get_device())
+ next_logits_all = []
+ pre_logits_all = []
+ mask_pad = torch.zeros([bs, 1],
+ dtype=torch.float32,
+ device=x.get_device())
+ for j in range(0, min(70, self.max_len - 1)):
+
+ prompt_char_next = torch.concat([
+ prompt_next_embed[:, :1, :],
+ prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
+ ], 1) # b, sub_l, dim
+ prompt_char_pre = torch.concat([
+ prompt_pre_embed[:, :1, :],
+ prompt_pre_embed[:, 1:, :] + self.char_embed(pre_id)
+ ], 1) # b, sub_l, dim
+ prompt_char = torch.concat([prompt_char_next, prompt_char_pre],
+ 0) #2, 6, dim
+ # prompt_char = prompt_char.flatten(0, 1)
+
+ mask_next = torch.where(next_id == self.bos_next,
+ float('-inf'), 0) # b, subs_l
+ mask_pre = torch.where(pre_id == self.bos_pre, float('-inf'),
+ 0) # b, subs_l
+ mask = torch.concat([mask_next, mask_pre], 0) #2, 5
+ mask = torch.concat([mask_pad, mask], 1) # 2, 6
+ pred_token = next_pre
+ visual_f_i = visual_f[:2] # 2 l dim
+ for layer in self.cmff_decoder:
+ pred_token = layer(pred_token, prompt_char, visual_f_i,
+ mask.unsqueeze(1))
+ logits_next_i = self.ques1_head(self.norm_pred(pred_token))
+ logits = F.softmax(logits_next_i, -1)
+ pred_id_i = logits.argmax(-1) #2, 1
+ # print(pred_id_i.shape)
+
+ next_pred_id_list[:, j:j + 1] = pred_id_i[:1]
+ pre_pred_id_list[:, j:j + 1] = pred_id_i[1:2]
+ if not (next_pred_id_list == self.eos).any(dim=-1).all():
+ next_logits_all.append(logits[:1])
+ next_id = torch.concat([next_id[:, 1:], pred_id_i[:1]], 1)
+ if not (pre_pred_id_list == self.eos).any(dim=-1).all():
+ pre_logits_all.append(logits[1:2])
+ pre_id = torch.concat([pred_id_i[1:2], pre_id[:, :-1]], 1)
+
+ if (next_pred_id_list == self.eos).any(dim=-1).all() and (
+ pre_pred_id_list == self.eos).any(dim=-1).all():
+ break
+ # print(next_id, pre_id)
+ # exit(0)
+ if len(next_logits_all) > self.sub_str_len and len(
+ pre_logits_all) > self.sub_str_len:
+ next_logits_all_ = torch.concat(next_logits_all[:-1],
+ 1) # 1, l
+ pre_logits_all_ = torch.concat(pre_logits_all[:-1][::-1],
+ 1) #1, l
+
+ next_id = next_logits_all_.argmax(-1)[:, -self.sub_str_len:]
+ pre_id = pre_logits_all_.argmax(-1)[:, :self.sub_str_len]
+ next_logits_all = []
+ ques_next = self.next_token.tile([1, 1, 1, 1]).squeeze(1)
+ mask_pad = torch.zeros([1, 1],
+ dtype=torch.float32,
+ device=x.get_device())
+ for j in range(0, min(70, self.max_len - 1)):
+
+ prompt_next = torch.concat([
+ prompt_next_embed[:, :1, :],
+ prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
+ ], 1) # b, sub_l, dim
+ mask_next = torch.where(next_id == self.bos_next,
+ float('-inf'), 0) # b, subs_l
+ mask = torch.concat([mask_pad, mask_next], 1)
+ # prompt_next = self.char_embed(prompt_id)
+ ques_next_i = ques_next
+ visual_f_i = visual_f[2:3]
+ for layer in self.cmff_decoder:
+ ques_next_i = layer(ques_next_i, prompt_next,
+ visual_f_i, mask.unsqueeze(1))
+ logits_next_i = self.ques1_head(
+ self.norm_pred(ques_next_i))
+ logits = F.softmax(logits_next_i, -1)
+ pred_id_i = logits.argmax(-1)
+ next_logits_all.append(logits)
+ next_id = torch.concat([next_id[:, 1:, ], pred_id_i], 1)
+ if next_id.equal(pre_id):
+ break
+ next_logits_all = torch.concat(next_logits_all, 1)
+ next_logits_all_ = torch.concat(
+ [next_logits_all_, next_logits_all], 1)
+
+ return torch.concat(
+ [next_logits_all_, pre_logits_all_[:, self.sub_str_len:]],
+ 1)
+ else:
+ return torch.concat(next_logits_all + pre_logits_all[::-1], 1)
+
+ def forward_test_bi_attn(self, x):
+ self.attn_maps = []
+ if not self.ds:
+ visual_f = x + self.vis_pos_embed
+ elif self.pos2d:
+ visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
+            visual_f = visual_f.flatten(2).transpose(1, 2)
+ else:
+ visual_f = x
+ bs = 2
+ if 1:
+ next = self.next_token
+ pre = self.pre_token
+ next_pre = torch.concat([next, pre], 0)
+ next_pre = next_pre.squeeze(1) #2, 1, dim
+
+ prompt_next_embed = self.prompt_next_embed.squeeze(1)
+ prompt_pre_embed = self.prompt_pre_embed.squeeze(1)
+
+ next_id = torch.full([1, self.sub_str_len], self.bos_next, dtype=torch.long, device=x.get_device())
+ pre_id = torch.full([1, self.sub_str_len], self.bos_pre, dtype=torch.long, device=x.get_device())
+ # prompt_next_bos = self.char_embed(prompt_id)
+ # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.get_device())
+ next_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.get_device())
+ pre_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.get_device())
+ next_logits_all = []
+ pre_logits_all = []
+ attn_map_next = []
+ attn_map_pre = []
+ mask_pad = torch.zeros([bs, 1], dtype=torch.float32, device=x.get_device())
+ for j in range(0, min(70, self.max_len-1)):
+
+ prompt_char_next = torch.concat([prompt_next_embed[:, :1, :], prompt_next_embed[:, 1:, :] + self.char_embed(next_id)], 1) # b, sub_l, dim
+ prompt_char_pre = torch.concat([prompt_pre_embed[:, :1, :], prompt_pre_embed[:, 1:, :] + self.char_embed(pre_id)], 1) # b, sub_l, dim
+ prompt_char = torch.concat([prompt_char_next, prompt_char_pre], 0) #2, 6, dim
+ # prompt_char = prompt_char.flatten(0, 1)
+
+ mask_next = torch.where(next_id == self.bos_next, float('-inf'), 0) # b, subs_l
+ mask_pre = torch.where(pre_id == self.bos_pre, float('-inf'), 0) # b, subs_l
+ mask = torch.concat([mask_next, mask_pre], 0) #2, 5
+ mask = torch.concat([mask_pad, mask], 1) # 2, 6
+ pred_token = next_pre
+ visual_f_i = visual_f[:2] # 2 l dim
+ for layer in self.cmff_decoder:
+ pred_token = layer(pred_token, prompt_char, visual_f_i, mask.unsqueeze(1))
+
+
+ logits_next_i = self.ques1_head(self.norm_pred(pred_token))
+ logits = F.softmax(logits_next_i, -1)
+ pred_id_i = logits.argmax(-1) #2, 1
+ # print(pred_id_i.shape)
+
+ next_pred_id_list[:, j:j+1] = pred_id_i[:1]
+ pre_pred_id_list[:, j:j+1] = pred_id_i[1:2]
+ if not (next_pred_id_list == self.eos).any(dim=-1).all():
+ next_logits_all.append(logits[:1])
+ attn_map_next.append(self.cmff_decoder[-1].question_to_images_cross_attn.attn_map[0])
+ next_id = torch.concat([next_id[:, 1:], pred_id_i[:1]], 1)
+ if not (pre_pred_id_list == self.eos).any(dim=-1).all():
+ pre_logits_all.append(logits[1:2])
+ attn_map_pre.append(self.cmff_decoder[-1].question_to_images_cross_attn.attn_map[1])
+ pre_id = torch.concat([pred_id_i[1:2], pre_id[:, :-1]], 1)
+
+ if (next_pred_id_list == self.eos).any(dim=-1).all() and (pre_pred_id_list == self.eos).any(dim=-1).all():
+ break
+ # print(next_id, pre_id)
+ # exit(0)
+ if len(next_logits_all) > self.sub_str_len and len(pre_logits_all) > self.sub_str_len:
+ next_logits_all_ = torch.concat(next_logits_all[:-1], 1) # 1, l
+ pre_logits_all_ = torch.concat(pre_logits_all[:-1][::-1], 1) #1, l
+
+ next_id = next_logits_all_.argmax(-1)[:, -self.sub_str_len:]
+ pre_id = pre_logits_all_.argmax(-1)[:, :self.sub_str_len]
+ next_logits_all_mid = []
+ attn_map_next_mid = []
+ ques_next = self.next_token.tile([1, 1, 1, 1]).squeeze(1)
+ mask_pad = torch.zeros([1, 1], dtype=torch.float32, device=x.get_device())
+ for j in range(0, min(70, self.max_len-1)):
+
+ prompt_next = torch.concat([prompt_next_embed[:, :1, :], prompt_next_embed[:, 1:, :] + self.char_embed(next_id)], 1) # b, sub_l, dim
+ mask_next = torch.where(next_id == self.bos_next, float('-inf'), 0) # b, subs_l
+ mask = torch.concat([mask_pad, mask_next], 1)
+ # prompt_next = self.char_embed(prompt_id)
+ ques_next_i = ques_next
+ visual_f_i = visual_f[2:3]
+ for layer in self.cmff_decoder:
+ ques_next_i = layer(ques_next_i, prompt_next, visual_f_i, mask.unsqueeze(1))
+ logits_next_i = self.ques1_head(self.norm_pred(ques_next_i))
+ attn_map_next_mid.append(self.cmff_decoder[-1].question_to_images_cross_attn.attn_map[0])
+ logits = F.softmax(logits_next_i, -1)
+ pred_id_i = logits.argmax(-1)
+ next_logits_all_mid.append(logits)
+ next_id = torch.concat([next_id[:, 1:, ], pred_id_i], 1)
+ if next_id.equal(pre_id):
+ break
+ next_logits_all_mid = torch.concat(next_logits_all_mid, 1)
+ # next_logits_all_ = torch.concat([next_logits_all_, next_logits_all], 1)
+ self.attn_maps = [attn_map_next, attn_map_next_mid, attn_map_pre[::-1]]
+ return [torch.concat(next_logits_all, 1), next_logits_all_mid, torch.concat(pre_logits_all[::-1], 1)]
+ else:
+ self.attn_maps = [attn_map_next, attn_map_pre[::-1]]
+ return [torch.concat(next_logits_all, 1), torch.concat(pre_logits_all[::-1], 1)]
+
+ def forward_test(self, x):
+ self.attn_maps = []
+ if not self.ds:
+ visual_f = x + self.vis_pos_embed
+ elif self.pos2d:
+ visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
+            visual_f = visual_f.flatten(2).transpose(1, 2)
+ else:
+ visual_f = x
+ bs = x.shape[0]
+
+ if self.next_mode:
+ ques_next = self.next_token.tile([bs, 1, 1, 1]).squeeze(1)
+ prompt_next_embed = self.prompt_next_embed.tile([bs, 1, 1,
+ 1]).squeeze(1)
+ prompt_id = torch.full([bs, self.sub_str_len],
+ self.bos_next,
+ dtype=torch.long,
+ device=x.get_device())
+ pred_id_list = torch.full([bs, self.max_len],
+ self.ignore_index,
+ dtype=torch.long,
+ device=x.get_device())
+ logits_all = []
+ mask_pad = torch.zeros([bs, 1],
+ dtype=torch.float32,
+ device=x.get_device())
+ for j in range(0, self.max_len - 1):
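+                # The prompt is a sliding window over the last sub_str_len
+                # predicted characters (initially filled with bos_next).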
+
+ prompt_next = torch.concat([
+ prompt_next_embed[:, :1, :],
+ prompt_next_embed[:, 1:, :] + self.char_embed(prompt_id)
+ ], 1) # b, sub_l, dim
+ mask_next = torch.where(prompt_id == self.bos_next,
+ float('-inf'), 0) # b, subs_l
+ mask = torch.concat([mask_pad, mask_next], 1)
+ ques_next_i = ques_next
+ visual_f_i = visual_f
+ for layer in self.cmff_decoder:
+ ques_next_i = layer(ques_next_i, prompt_next, visual_f_i,
+ mask.unsqueeze(1))
+ self.attn_maps.append(
+ self.cmff_decoder[-1].question_to_images_cross_attn.
+ attn_map[0])
+ logits_next_i = self.ques1_head(self.norm_pred(ques_next_i))
+ logits = F.softmax(logits_next_i, -1)
+ pred_id_i = logits.argmax(-1)
+ logits_all.append(logits)
+ pred_id_list[:, j:j + 1] = pred_id_i
+ if (pred_id_list == self.eos).any(dim=-1).all():
+ break
+ prompt_id = torch.concat(
+ [
+ prompt_id[:, 1:, ],
+ pred_id_i,
+ ],
+ 1,
+ )
+ return torch.concat(logits_all, 1)
+ else:
+ ques_next = self.pre_token.tile([bs, 1, 1, 1]).squeeze(1)
+ prompt_pre_embed = self.prompt_pre_embed.tile([bs, 1, 1,
+ 1]).squeeze(1)
+ prompt_id = torch.full([bs, self.sub_str_len],
+ self.bos_pre,
+ dtype=torch.long,
+ device=x.get_device())
+ pred_id_list = torch.full([bs, self.max_len],
+ self.ignore_index,
+ dtype=torch.long,
+ device=x.get_device())
+ logits_all = []
+ mask_pad = torch.zeros([bs, 1],
+ dtype=torch.float32,
+ device=x.get_device())
+ for j in range(0, self.max_len - 1):
+
+ prompt_next = torch.concat([
+ prompt_pre_embed[:, :1, :],
+ prompt_pre_embed[:, 1:, :] + self.char_embed(prompt_id)
+ ], 1) # b, sub_l, dim
+ mask_next = torch.where(prompt_id == self.bos_pre,
+ float('-inf'), 0) # b, subs_l
+ mask = torch.concat([mask_pad, mask_next], 1)
+ ques_next_i = ques_next
+ visual_f_i = visual_f
+ for layer in self.cmff_decoder:
+ ques_next_i = layer(ques_next_i, prompt_next, visual_f_i,
+ mask.unsqueeze(1))
+ logits_next_i = self.ques1_head(self.norm_pred(ques_next_i))
+ logits = F.softmax(logits_next_i, -1)
+ pred_id_i = logits.argmax(-1)
+ logits_all.append(logits)
+ pred_id_list[:, j:j + 1] = pred_id_i
+ if (pred_id_list == self.eos).any(dim=-1).all():
+ break
+ prompt_id = torch.concat(
+ [
+ pred_id_i,
+ prompt_id[:, :-1, ],
+ ],
+ 1,
+ )
+ return torch.concat(logits_all, 1)
+
+ def forward_train(self, x, targets=None):
+ bs = x.shape[0]
+
+ if not self.ds:
+ visual_f = x + self.vis_pos_embed
+ elif self.pos2d:
+ visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
+ else:
+ visual_f = x
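+        # Training decodes next- and previous-character queries in parallel:
+        # every substring prompt from the label asks for the character that
+        # follows (next branch) or precedes (pre branch) it.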
+ max_len_curr = targets[3].max()
+ subs = targets[1][:, :max_len_curr, :] # b, n, subs_l
+ mask_next = torch.where(subs == self.bos_next, float('-inf'),
+ 0) # b, n, subs_l
+ prompt_next_embed = self.prompt_next_embed.tile(
+ [bs, max_len_curr, 1, 1])
+ prompt_char_next = torch.concat([
+ prompt_next_embed[:, :, :1, :],
+ prompt_next_embed[:, :, 1:, :] + self.char_embed(subs)
+ ], 2) # b, n, sub_l, dim
+ next = self.next_token.tile([bs, max_len_curr, 1, 1])
+
+ max_len_curr_pre = targets[6].max()
+ subs = targets[4][:, :max_len_curr_pre, :] # b, n, subs_l
+ mask_pre = torch.where(subs == self.bos_pre, float('-inf'),
+ 0) # b, n, subs_l
+ prompt_pre_embed = self.prompt_pre_embed.tile(
+ [bs, max_len_curr_pre, 1, 1])
+ prompt_char_pre = torch.concat([
+ prompt_pre_embed[:, :, :1, :],
+ prompt_pre_embed[:, :, 1:, :] + self.char_embed(subs)
+ ], 2) # b, n, sub_l, dim
+ pre = self.pre_token.tile([bs, max_len_curr_pre, 1, 1]) # b, n, 1, dim
+
+ prompt_char = torch.concat([prompt_char_next, prompt_char_pre], 1)
+ next_pre = torch.concat([next, pre], 1)
+
+ mask_pad = torch.zeros([bs * (max_len_curr + max_len_curr_pre), 1],
+ dtype=torch.float32,
+ device=x.get_device())
+ mask = torch.concat([mask_next, mask_pre], 1).flatten(0, 1)
+ mask = torch.concat([mask_pad, mask], 1)
+ next_pre = next_pre.flatten(0, 1)
+ prompt_char = prompt_char.flatten(0, 1)
+ for layer in self.cmff_decoder:
+ next_pre = layer(next_pre, prompt_char, visual_f,
+ mask.unsqueeze(1))
+ answer1_pred = self.ques1_head(self.norm_pred(next_pre))
+ logits = answer1_pred[:, :max_len_curr]
+
+ label = torch.concat(
+ [targets[2][:, :max_len_curr], targets[5][:, :max_len_curr_pre]],
+ 1)
+ loss1 = F.cross_entropy(answer1_pred.flatten(0, 1),
+ label.flatten(0, 1),
+ ignore_index=self.ignore_index,
+ reduction='mean')
+ loss = {'loss': loss1}
+ return [loss, logits]
diff --git a/openrec/modeling/decoders/smtr_decoder_nattn.py b/openrec/modeling/decoders/smtr_decoder_nattn.py
new file mode 100644
index 0000000000000000000000000000000000000000..84471730999bdbadd983a7fa69004470e1f68648
--- /dev/null
+++ b/openrec/modeling/decoders/smtr_decoder_nattn.py
@@ -0,0 +1,521 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.init import ones_, trunc_normal_, zeros_
+
+from openrec.modeling.common import DropPath, Identity
+from openrec.modeling.decoders.cppd_decoder import DecoderLayer
+from openrec.modeling.decoders.nrtr_decoder import Embeddings
+
+
+class CrossAttention(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, q, kv, key_mask=None):
+ N, C = kv.shape[1:]
+ QN = q.shape[1]
+ q = self.q(q).reshape([-1, QN, self.num_heads,
+ C // self.num_heads]).transpose(1, 2)
+ q = q * self.scale
+ k, v = self.kv(kv).reshape(
+ [-1, N, 2, self.num_heads,
+ C // self.num_heads]).permute(2, 0, 3, 1, 4)
+
+ attn = q.matmul(k.transpose(2, 3))
+
+ if key_mask is not None:
+ attn = attn + key_mask.unsqueeze(1)
+
+ attn = F.softmax(attn, -1)
+ if not self.training:
+ self.attn_map = attn
+ attn = self.attn_drop(attn)
+
+ x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SSMatchLayer(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ nextq2subs_head2=None,
+ dynq2img_heads=2,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ num_layer=2,
+ epsilon=1e-6,
+ ):
+ super().__init__()
+ self.dim = dim
+ if nextq2subs_head2 is None:
+ nextq2subs_head2 = dim // 32
+ self.normq1 = nn.LayerNorm(dim, eps=epsilon)
+ self.normkv1 = nn.LayerNorm(dim, eps=epsilon)
+ self.images_to_question_cross_attn = CrossAttention(
+ dim,
+ num_heads=nextq2subs_head2,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop)
+ self.normq2 = nn.LayerNorm(dim, eps=epsilon)
+ # self.normkv2 = nn.LayerNorm(dim, eps=epsilon)
+ dpr = np.linspace(0, drop_path, num_layer)
+ self.question_to_images_cross_attn = nn.ModuleList([
+ DecoderLayer(
+ dim=dim,
+ num_heads=dynq2img_heads,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=dpr[i],
+ act_layer=act_layer,
+ ) for i in range(num_layer)
+ ])
+ # CrossAttention(
+ # dim,
+ # num_heads=dynq2img_heads,
+ # qkv_bias=qkv_bias,
+ # qk_scale=qk_scale,
+ # attn_drop=attn_drop,
+ # proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+
+ def forward(self, question_f, prompt_f, visual_f, mask=None):
+
+ question_f = question_f + self.drop_path(
+ self.images_to_question_cross_attn(self.normq1(question_f),
+ self.normkv1(prompt_f), mask))
+ question_f = question_f.reshape(visual_f.shape[0], -1, self.dim)
+ question_f = self.normq2(question_f)
+ # kv = self.normkv2(visual_f)
+ for layer in self.question_to_images_cross_attn:
+ question_f = layer(question_f, visual_f)
+
+ return question_f
+
+
+class SMTRDecoderNumAttn(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_layer=2,
+ nextq2subs_head2=None,
+ dynq2img_heads=2,
+ drop_path_rate=0.1,
+ max_len=25,
+ vis_seq=50,
+ ds=False,
+ pos2d=False,
+ max_size=[8, 32],
+ sub_str_len=5,
+ next_mode=True,
+ infer_aug=False,
+ **kwargs):
+ super(SMTRDecoderNumAttn, self).__init__()
+
+ self.out_channels = out_channels
+ dim = in_channels
+ self.dim = dim
+ self.max_len = max_len + 3 # max_len + eos + bos
+ self.char_embed = Embeddings(d_model=dim,
+ vocab=self.out_channels,
+ scale_embedding=True)
+ self.ignore_index = out_channels - 1
+ self.sub_str_len = sub_str_len
+ self.bos_next = out_channels - 3
+ self.bos_pre = out_channels - 2
+ self.eos = 0
+ self.next_mode = next_mode
+ self.infer_aug = infer_aug
+ self.cmff_decoder = SSMatchLayer(dim=dim,
+ nextq2subs_head2=nextq2subs_head2,
+ dynq2img_heads=dynq2img_heads,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=drop_path_rate,
+ num_layer=num_layer)
+
+ self.ds = ds
+ self.pos2d = pos2d
+ if not ds:
+ self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+ trunc_normal_(self.vis_pos_embed, std=0.02)
+ elif pos2d:
+ pos_embed = torch.zeros([1, max_size[0] * max_size[1], dim],
+ dtype=torch.float32)
+ trunc_normal_(pos_embed, mean=0, std=0.02)
+ self.vis_pos_embed = nn.Parameter(pos_embed.transpose(
+ 1, 2).reshape(1, dim, max_size[0], max_size[1]),
+ requires_grad=True)
+
+ self.next_token = nn.Parameter(torch.zeros([1, 1, 1, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+
+ self.pre_token = nn.Parameter(torch.zeros([1, 1, 1, dim],
+ dtype=torch.float32),
+ requires_grad=True)
+
+ self.prompt_next_embed = nn.Parameter(torch.zeros(
+ [1, 1, self.sub_str_len + 1, dim], dtype=torch.float32),
+ requires_grad=True)
+
+ self.prompt_pre_embed = nn.Parameter(torch.zeros(
+ [1, 1, self.sub_str_len + 1, dim], dtype=torch.float32),
+ requires_grad=True)
+
+ self.norm_pred = nn.LayerNorm(dim, eps=1e-6)
+ self.ques1_head = nn.Linear(dim, self.out_channels - 3)
+
+ trunc_normal_(self.next_token, std=0.02)
+ trunc_normal_(self.pre_token, std=0.02)
+ trunc_normal_(self.prompt_pre_embed, std=0.02)
+ trunc_normal_(self.prompt_next_embed, std=0.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'vis_pos_embed', 'pre_token', 'next_token', 'char_embed'}
+
+ def forward(self, x, data=None):
+ if self.training:
+ return self.forward_train(x, data)
+ else:
+ if self.infer_aug:
+ return self.forward_test_bi(x)
+ return self.forward_test(x)
+
+ def forward_test_bi(self, x):
+ # self.attn_maps = []
+ if not self.ds:
+ visual_f = x + self.vis_pos_embed
+ elif self.pos2d:
+ visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
+            visual_f = visual_f.flatten(2).transpose(1, 2)
+ else:
+ visual_f = x
+ bs = 2
+ if 1:
+ next = self.next_token
+ pre = self.pre_token
+ next_pre = torch.concat([next, pre], 0)
+ next_pre = next_pre.squeeze(1) #2, 1, dim
+
+ prompt_next_embed = self.prompt_next_embed.squeeze(1)
+ prompt_pre_embed = self.prompt_pre_embed.squeeze(1)
+
+ next_id = torch.full([1, self.sub_str_len],
+ self.bos_next,
+ dtype=torch.long,
+ device=x.get_device())
+ pre_id = torch.full([1, self.sub_str_len],
+ self.bos_pre,
+ dtype=torch.long,
+ device=x.get_device())
+ # prompt_next_bos = self.char_embed(prompt_id)
+ # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.get_device())
+ next_pred_id_list = torch.full([1, self.max_len],
+ self.ignore_index,
+ dtype=torch.long,
+ device=x.get_device())
+ pre_pred_id_list = torch.full([1, self.max_len],
+ self.ignore_index,
+ dtype=torch.long,
+ device=x.get_device())
+ next_logits_all = []
+ pre_logits_all = []
+ mask_pad = torch.zeros([bs, 1],
+ dtype=torch.float32,
+ device=x.get_device())
+ for j in range(0, min(70, self.max_len - 1)):
+
+ prompt_char_next = torch.concat([
+ prompt_next_embed[:, :1, :],
+ prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
+ ], 1) # b, sub_l, dim
+ prompt_char_pre = torch.concat([
+ prompt_pre_embed[:, :1, :],
+ prompt_pre_embed[:, 1:, :] + self.char_embed(pre_id)
+ ], 1) # b, sub_l, dim
+ prompt_char = torch.concat([prompt_char_next, prompt_char_pre],
+ 0) #2, 6, dim
+ # prompt_char = prompt_char.flatten(0, 1)
+
+ mask_next = torch.where(next_id == self.bos_next,
+ float('-inf'), 0) # b, subs_l
+ mask_pre = torch.where(pre_id == self.bos_pre, float('-inf'),
+ 0) # b, subs_l
+ mask = torch.concat([mask_next, mask_pre], 0) #2, 5
+ mask = torch.concat([mask_pad, mask], 1) # 2, 6
+ pred_token = next_pre
+ visual_f_i = visual_f[:2] # 2 l dim
+ pred_token = self.cmff_decoder(pred_token, prompt_char,
+ visual_f_i, mask.unsqueeze(1))
+ logits_next_i = self.ques1_head(self.norm_pred(pred_token))
+ logits = F.softmax(logits_next_i, -1)
+ pred_id_i = logits.argmax(-1) #2, 1
+ # print(pred_id_i.shape)
+
+ next_pred_id_list[:, j:j + 1] = pred_id_i[:1]
+ pre_pred_id_list[:, j:j + 1] = pred_id_i[1:2]
+ if not (next_pred_id_list == self.eos).any(dim=-1).all():
+ next_logits_all.append(logits[:1])
+ next_id = torch.concat([next_id[:, 1:], pred_id_i[:1]], 1)
+ if not (pre_pred_id_list == self.eos).any(dim=-1).all():
+ pre_logits_all.append(logits[1:2])
+ pre_id = torch.concat([pred_id_i[1:2], pre_id[:, :-1]], 1)
+
+ if (next_pred_id_list == self.eos).any(dim=-1).all() and (
+ pre_pred_id_list == self.eos).any(dim=-1).all():
+ break
+ # print(next_id, pre_id)
+ # exit(0)
+ if len(next_logits_all) > self.sub_str_len and len(
+ pre_logits_all) > self.sub_str_len:
+ next_logits_all_ = torch.concat(next_logits_all[:-1],
+ 1) # 1, l
+ pre_logits_all_ = torch.concat(pre_logits_all[:-1][::-1],
+ 1) #1, l
+
+ next_id = next_logits_all_.argmax(-1)[:, -self.sub_str_len:]
+ pre_id = pre_logits_all_.argmax(-1)[:, :self.sub_str_len]
+ next_logits_all = []
+ ques_next = self.next_token.tile([1, 1, 1, 1]).squeeze(1)
+ mask_pad = torch.zeros([1, 1],
+ dtype=torch.float32,
+ device=x.get_device())
+ for j in range(0, min(70, self.max_len - 1)):
+
+ prompt_next = torch.concat([
+ prompt_next_embed[:, :1, :],
+ prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
+ ], 1) # b, sub_l, dim
+ mask_next = torch.where(next_id == self.bos_next,
+ float('-inf'), 0) # b, subs_l
+ mask = torch.concat([mask_pad, mask_next], 1)
+ # prompt_next = self.char_embed(prompt_id)
+ ques_next_i = ques_next
+ visual_f_i = visual_f[2:3]
+ ques_next_i = self.cmff_decoder(ques_next_i, prompt_next,
+ visual_f_i,
+ mask.unsqueeze(1))
+ logits_next_i = self.ques1_head(
+ self.norm_pred(ques_next_i))
+ logits = F.softmax(logits_next_i, -1)
+ pred_id_i = logits.argmax(-1)
+ next_logits_all.append(logits)
+ next_id = torch.concat([next_id[:, 1:, ], pred_id_i], 1)
+ if next_id.equal(pre_id):
+ break
+ next_logits_all = torch.concat(next_logits_all, 1)
+ next_logits_all_ = torch.concat(
+ [next_logits_all_, next_logits_all], 1)
+
+ return torch.concat(
+ [next_logits_all_, pre_logits_all_[:, self.sub_str_len:]],
+ 1)
+ else:
+ return torch.concat(next_logits_all + pre_logits_all[::-1], 1)
+
+ def forward_test(self, x):
+ # self.attn_maps = []
+ if not self.ds:
+ visual_f = x + self.vis_pos_embed
+ elif self.pos2d:
+ visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
+            visual_f = visual_f.flatten(2).transpose(1, 2)
+ else:
+ visual_f = x
+ bs = x.shape[0]
+
+ if self.next_mode:
+ ques_next = self.next_token.tile([bs, 1, 1, 1]).squeeze(1)
+ prompt_next_embed = self.prompt_next_embed.tile([bs, 1, 1,
+ 1]).squeeze(1)
+ prompt_id = torch.full([bs, self.sub_str_len],
+ self.bos_next,
+ dtype=torch.long,
+ device=x.get_device())
+ pred_id_list = torch.full([bs, self.max_len],
+ self.ignore_index,
+ dtype=torch.long,
+ device=x.get_device())
+ logits_all = []
+ mask_pad = torch.zeros([bs, 1],
+ dtype=torch.float32,
+ device=x.get_device())
+ for j in range(0, self.max_len - 1):
+
+ prompt_next = torch.concat([
+ prompt_next_embed[:, :1, :],
+ prompt_next_embed[:, 1:, :] + self.char_embed(prompt_id)
+ ], 1) # b, sub_l, dim
+ mask_next = torch.where(prompt_id == self.bos_next,
+ float('-inf'), 0) # b, subs_l
+ mask = torch.concat([mask_pad, mask_next], 1)
+ ques_next_i = ques_next
+ visual_f_i = visual_f
+ ques_next_i = self.cmff_decoder(ques_next_i, prompt_next,
+ visual_f_i, mask.unsqueeze(1))
+ # self.attn_maps.append(
+ # self.cmff_decoder[-1].question_to_images_cross_attn.
+ # attn_map[0])
+ logits_next_i = self.ques1_head(self.norm_pred(ques_next_i))
+ logits = F.softmax(logits_next_i, -1)
+ pred_id_i = logits.argmax(-1)
+ logits_all.append(logits)
+ pred_id_list[:, j:j + 1] = pred_id_i
+ if (pred_id_list == self.eos).any(dim=-1).all():
+ break
+ prompt_id = torch.concat(
+ [
+ prompt_id[:, 1:, ],
+ pred_id_i,
+ ],
+ 1,
+ )
+ return torch.concat(logits_all, 1)
+ else:
+ ques_next = self.pre_token.tile([bs, 1, 1, 1]).squeeze(1)
+ prompt_pre_embed = self.prompt_pre_embed.tile([bs, 1, 1,
+ 1]).squeeze(1)
+ prompt_id = torch.full([bs, self.sub_str_len],
+ self.bos_pre,
+ dtype=torch.long,
+ device=x.get_device())
+ pred_id_list = torch.full([bs, self.max_len],
+ self.ignore_index,
+ dtype=torch.long,
+ device=x.get_device())
+ logits_all = []
+ mask_pad = torch.zeros([bs, 1],
+ dtype=torch.float32,
+ device=x.get_device())
+ for j in range(0, self.max_len - 1):
+
+ prompt_next = torch.concat([
+ prompt_pre_embed[:, :1, :],
+ prompt_pre_embed[:, 1:, :] + self.char_embed(prompt_id)
+ ], 1) # b, sub_l, dim
+ mask_next = torch.where(prompt_id == self.bos_pre,
+ float('-inf'), 0) # b, subs_l
+ mask = torch.concat([mask_pad, mask_next], 1)
+ ques_next_i = ques_next
+ visual_f_i = visual_f
+ ques_next_i = self.cmff_decoder(ques_next_i, prompt_next,
+ visual_f_i, mask.unsqueeze(1))
+ logits_next_i = self.ques1_head(self.norm_pred(ques_next_i))
+ logits = F.softmax(logits_next_i, -1)
+ pred_id_i = logits.argmax(-1)
+ logits_all.append(logits)
+ pred_id_list[:, j:j + 1] = pred_id_i
+ if (pred_id_list == self.eos).any(dim=-1).all():
+ break
+ prompt_id = torch.concat(
+ [
+ pred_id_i,
+ prompt_id[:, :-1, ],
+ ],
+ 1,
+ )
+ return torch.concat(logits_all, 1)
+
+ def forward_train(self, x, targets=None):
+ bs = x.shape[0]
+
+ if not self.ds:
+ visual_f = x + self.vis_pos_embed
+ elif self.pos2d:
+ visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
+ else:
+ visual_f = x
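+        # Same bidirectional substring-prompt training scheme as SMTRDecoder,
+        # but routed through a single SSMatchLayer (self.cmff_decoder).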
+ max_len_curr = targets[3].max()
+ subs = targets[1][:, :max_len_curr, :] # b, n, subs_l
+ mask_next = torch.where(subs == self.bos_next, float('-inf'),
+ 0) # b, n, subs_l
+ prompt_next_embed = self.prompt_next_embed.tile(
+ [bs, max_len_curr, 1, 1])
+ prompt_char_next = torch.concat([
+ prompt_next_embed[:, :, :1, :],
+ prompt_next_embed[:, :, 1:, :] + self.char_embed(subs)
+ ], 2) # b, n, sub_l, dim
+ next = self.next_token.tile([bs, max_len_curr, 1, 1])
+
+ max_len_curr_pre = targets[6].max()
+ subs = targets[4][:, :max_len_curr_pre, :] # b, n, subs_l
+ mask_pre = torch.where(subs == self.bos_pre, float('-inf'),
+ 0) # b, n, subs_l
+ prompt_pre_embed = self.prompt_pre_embed.tile(
+ [bs, max_len_curr_pre, 1, 1])
+ prompt_char_pre = torch.concat([
+ prompt_pre_embed[:, :, :1, :],
+ prompt_pre_embed[:, :, 1:, :] + self.char_embed(subs)
+ ], 2) # b, n, sub_l, dim
+ pre = self.pre_token.tile([bs, max_len_curr_pre, 1, 1]) # b, n, 1, dim
+
+ prompt_char = torch.concat([prompt_char_next, prompt_char_pre], 1)
+ next_pre = torch.concat([next, pre], 1)
+
+ mask_pad = torch.zeros([bs * (max_len_curr + max_len_curr_pre), 1],
+ dtype=torch.float32,
+ device=x.get_device())
+ mask = torch.concat([mask_next, mask_pre], 1).flatten(0, 1)
+ mask = torch.concat([mask_pad, mask], 1)
+ next_pre = next_pre.flatten(0, 1)
+ prompt_char = prompt_char.flatten(0, 1)
+ next_pre = self.cmff_decoder(next_pre, prompt_char, visual_f,
+ mask.unsqueeze(1))
+ answer1_pred = self.ques1_head(self.norm_pred(next_pre))
+ logits = answer1_pred[:, :max_len_curr]
+
+ label = torch.concat(
+ [targets[2][:, :max_len_curr], targets[5][:, :max_len_curr_pre]],
+ 1)
+ loss1 = F.cross_entropy(answer1_pred.flatten(0, 1),
+ label.flatten(0, 1),
+ ignore_index=self.ignore_index,
+ reduction='mean')
+ loss = {'loss': loss1}
+ return [loss, logits]
diff --git a/openrec/modeling/decoders/srn_decoder.py b/openrec/modeling/decoders/srn_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3e93666ce7acfcc88772f2478d1abe9f2c5e83e
--- /dev/null
+++ b/openrec/modeling/decoders/srn_decoder.py
@@ -0,0 +1,283 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .nrtr_decoder import Embeddings, TransformerBlock
+
+
+class PVAM(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ char_num,
+ max_text_length,
+ num_heads,
+ hidden_dims,
+ dropout_rate=0):
+ super(PVAM, self).__init__()
+ self.char_num = char_num
+ self.max_length = max_text_length
+ self.num_heads = num_heads
+ self.hidden_dims = hidden_dims
+ self.dropout_rate = dropout_rate
+ #TODO
+ self.emb = nn.Embedding(num_embeddings=256,
+ embedding_dim=hidden_dims,
+ sparse=False)
+ self.drop_out = nn.Dropout(dropout_rate)
+ self.feat_emb = nn.Linear(in_channels, in_channels)
+ self.token_emb = nn.Embedding(max_text_length, in_channels)
+ self.score = nn.Linear(in_channels, 1, bias=False)
+
+ def feat_pos_mix(self, conv_features, encoder_word_pos, dropout_rate):
+ #b h*w c
+ pos_emb = self.emb(encoder_word_pos)
+ # pos_emb = pos_emb.detach()
+ enc_input = conv_features + pos_emb
+
+ if dropout_rate:
+ enc_input = self.drop_out(enc_input)
+
+ return enc_input
+
+ def forward(self, inputs):
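+        # Parallel Visual Attention: every reading-order position attends to
+        # the position-encoded visual features at once, so all characters are
+        # aligned in parallel rather than autoregressively.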
+ b, c, h, w = inputs.shape
+ conv_features = inputs.view(-1, c, h * w)
+ conv_features = conv_features.permute(0, 2, 1).contiguous()
+ # b h*w c
+
+ # transformer encoder
+ b, t, c = conv_features.shape
+
+ encoder_feat_pos = torch.arange(t, dtype=torch.long).to(inputs.device)
+
+ enc_inputs = self.feat_pos_mix(conv_features, encoder_feat_pos,
+ self.dropout_rate)
+
+ inputs = self.feat_emb(enc_inputs) # feat emb
+
+ inputs = inputs.unsqueeze(1).expand(-1, self.max_length, -1, -1)
+
+ # b maxlen h*w c
+
+ tokens_pos = torch.arange(self.max_length,
+ dtype=torch.long).to(inputs.device)
+ tokens_pos = tokens_pos.unsqueeze(0).expand(b, -1)
+
+ tokens_pos_emd = self.token_emb(tokens_pos)
+ tokens_pos_emd = tokens_pos_emd.unsqueeze(2).expand(-1, -1, t, -1)
+ # b maxlen h*w c
+
+ attention_weight = torch.tanh(tokens_pos_emd + inputs)
+
+ attention_weight = torch.squeeze(self.score(attention_weight),
+ -1) #b,25,256
+
+ attention_weight = F.softmax(attention_weight, dim=-1) #b,25,256
+
+ pvam_features = torch.matmul(attention_weight, enc_inputs)
+
+ return pvam_features
+
+
+class GSRM(nn.Module):
+
+ def __init__(self,
+ in_channel,
+ char_num,
+ max_len,
+ num_heads,
+ hidden_dims,
+ num_layers,
+ dropout_rate=0,
+ attention_dropout=0.1):
+ super(GSRM, self).__init__()
+ self.char_num = char_num
+ self.max_len = max_len
+ self.num_heads = num_heads
+
+ self.cls_op = nn.Linear(in_channel, self.char_num)
+ self.cls_final = nn.Linear(in_channel, self.char_num)
+
+ self.word_emb = Embeddings(d_model=hidden_dims, vocab=char_num)
+ self.pos_emb = nn.Embedding(char_num, hidden_dims)
+ self.dropout_rate = dropout_rate
+ self.emb_drop_out = nn.Dropout(dropout_rate)
+
+ self.forward_self_attn = nn.ModuleList([
+ TransformerBlock(
+ d_model=hidden_dims,
+ nhead=num_heads,
+ attention_dropout_rate=attention_dropout,
+ residual_dropout_rate=0.1,
+ dim_feedforward=hidden_dims,
+ with_self_attn=True,
+ with_cross_attn=False,
+ ) for i in range(num_layers)
+ ])
+
+ self.backward_self_attn = nn.ModuleList([
+ TransformerBlock(
+ d_model=hidden_dims,
+ nhead=num_heads,
+ attention_dropout_rate=attention_dropout,
+ residual_dropout_rate=0.1,
+ dim_feedforward=hidden_dims,
+ with_self_attn=True,
+ with_cross_attn=False,
+ ) for i in range(num_layers)
+ ])
+
+ def _pos_emb(self, word_seq, pos, dropoutrate):
+ """
+ word_Seq: bsz len
+ pos: bsz len
+ """
+ word_emb_seq = self.word_emb(word_seq)
+ pos_emb_seq = self.pos_emb(pos)
+ # pos_emb_seq = pos_emb_seq.detach()
+
+ input_mix = word_emb_seq + pos_emb_seq
+ if dropoutrate > 0:
+ input_mix = self.emb_drop_out(input_mix)
+
+ return input_mix
+
+ def forward(self, inputs):
+
+ bos_idx = self.char_num - 2
+ eos_idx = self.char_num - 1
+ b, t, c = inputs.size() #b,25,512
+ inputs = inputs.view(-1, c)
+        cls_res = self.cls_op(inputs)  # (b*t, n_class)
+
+ word_pred_PVAM = F.softmax(cls_res, dim=-1).argmax(-1)
+ word_pred_PVAM = word_pred_PVAM.view(-1, t, 1)
+ #b 25 1
+ word1 = F.pad(word_pred_PVAM, [0, 0, 1, 0], 'constant', value=bos_idx)
+ word_forward = word1[:, :-1, :].squeeze(-1)
+
+ word_backward = word_pred_PVAM.squeeze(-1)
+
+ #mask
+ attn_mask_forward = torch.triu(
+ torch.full((self.max_len, self.max_len),
+ dtype=torch.float32,
+ fill_value=-torch.inf),
+ diagonal=1,
+ ).to(inputs.device)
+ attn_mask_forward = attn_mask_forward.unsqueeze(0).expand(
+ self.num_heads, -1, -1)
+ attn_mask_backward = torch.tril(
+ torch.full((self.max_len, self.max_len),
+ dtype=torch.float32,
+ fill_value=-torch.inf),
+ diagonal=-1,
+ ).to(inputs.device)
+ attn_mask_backward = attn_mask_backward.unsqueeze(0).expand(
+ self.num_heads, -1, -1)
+
+ #B,25
+
+ pos = torch.arange(self.max_len, dtype=torch.long).to(inputs.device)
+ pos = pos.unsqueeze(0).expand(b, -1) #b,25
+
+ word_front_mix = self._pos_emb(word_forward, pos, self.dropout_rate)
+ word_backward_mix = self._pos_emb(word_backward, pos,
+ self.dropout_rate)
+ # b 25 emb_dim
+
+ for attn_layer in self.forward_self_attn:
+ word_front_mix = attn_layer(word_front_mix,
+ self_mask=attn_mask_forward)
+
+ for attn_layer in self.backward_self_attn:
+ word_backward_mix = attn_layer(word_backward_mix,
+ self_mask=attn_mask_backward)
+
+ #b,25,emb_dim
+ eos_emd = self.word_emb(torch.full(
+ (1, ), eos_idx).to(inputs.device)).expand(b, 1, -1)
+ word_backward_mix = torch.cat((word_backward_mix, eos_emd), dim=1)
+        word_backward_mix = word_backward_mix[:, 1:, :]
+
+ gsrm_features = word_front_mix + word_backward_mix
+
+ gsrm_out = self.cls_final(gsrm_features)
+ # torch.matmul(gsrm_features,
+ # self.word_emb.embedding.weight.permute(1, 0))
+
+ b, t, c = gsrm_out.size()
+ #b,25,n_class
+ gsrm_out = gsrm_out.view(-1, c).contiguous()
+
+ return gsrm_features, cls_res, gsrm_out
+
+
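+# VSFD (Visual-Semantic Fusion Decoder): fuses the visual (PVAM) and semantic
+# (GSRM) features with a learned gate. Sketch of the computation:
+#   g = sigmoid(fc0(concat([pvam, gsrm], dim=-1)))   # (b*t, in_channels)
+#   fused = g * pvam + (1 - g) * gsrm
+#   logits = fc1(fused)                              # (b*t, char_num)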
+class VSFD(nn.Module):
+
+ def __init__(self, in_channels, out_channels):
+ super(VSFD, self).__init__()
+ self.char_num = out_channels
+ self.fc0 = nn.Linear(in_channels * 2, in_channels)
+ self.fc1 = nn.Linear(in_channels, self.char_num)
+
+ def forward(self, pvam_feature, gsrm_feature):
+ _, t, c1 = pvam_feature.size()
+ _, t, c2 = gsrm_feature.size()
+        combine_features = torch.cat([pvam_feature, gsrm_feature], dim=-1)
+        combine_features = combine_features.view(-1, c1 + c2).contiguous()
+        atten = self.fc0(combine_features)
+        atten = torch.sigmoid(atten)
+        atten = atten.view(-1, t, c1)
+        combine_features = atten * pvam_feature + (1 - atten) * gsrm_feature
+        combine_features = combine_features.view(-1, c1).contiguous()
+        out = self.fc1(combine_features)
+ return out
+
+
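+# SRNDecoder wires the three modules together: PVAM aligns visual features to
+# reading-order slots, GSRM reasons over the resulting character sequence, and
+# VSFD fuses the two streams. Training returns [pvam_preds, gsrm_preds,
+# vsfd_preds] so every branch can be supervised; inference returns only the
+# softmaxed VSFD output.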
+class SRNDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ hidden_dims,
+ num_decoder_layers=4,
+ max_text_length=25,
+ num_heads=8,
+ **kwargs):
+ super(SRNDecoder, self).__init__()
+
+ self.max_text_length = max_text_length
+ self.num_heads = num_heads
+
+ self.pvam = PVAM(in_channels=in_channels,
+ char_num=out_channels,
+ max_text_length=max_text_length,
+ num_heads=num_heads,
+ hidden_dims=hidden_dims,
+ dropout_rate=0.1)
+
+ self.gsrm = GSRM(in_channel=in_channels,
+ char_num=out_channels,
+ max_len=max_text_length,
+ num_heads=num_heads,
+ num_layers=num_decoder_layers,
+ hidden_dims=hidden_dims)
+
+ self.vsfd = VSFD(in_channels=in_channels, out_channels=out_channels)
+
+ def forward(self, feat, data=None):
+ # feat [B,512,8,32]
+
+ pvam_feature = self.pvam(feat)
+
+ gsrm_features, pvam_preds, gsrm_preds = self.gsrm(pvam_feature)
+
+ vsfd_preds = self.vsfd(pvam_feature, gsrm_features)
+
+ if not self.training:
+ preds = F.softmax(vsfd_preds, dim=-1)
+ return preds
+
+ return [pvam_preds, gsrm_preds, vsfd_preds]
diff --git a/openrec/modeling/decoders/visionlan_decoder.py b/openrec/modeling/decoders/visionlan_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb0d6b987967302f3177aba426676ced5765fac0
--- /dev/null
+++ b/openrec/modeling/decoders/visionlan_decoder.py
@@ -0,0 +1,321 @@
+import torch
+import torch.nn as nn
+
+from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock
+
+
+class Transformer_Encoder(nn.Module):
+
+ def __init__(
+ self,
+ n_layers=3,
+ n_head=8,
+ d_model=512,
+ d_inner=2048,
+ dropout=0.1,
+ n_position=256,
+ ):
+
+ super(Transformer_Encoder, self).__init__()
+ self.pe = PositionalEncoding(dropout=dropout,
+ dim=d_model,
+ max_len=n_position)
+ self.layer_stack = nn.ModuleList([
+ TransformerBlock(d_model, n_head, d_inner) for _ in range(n_layers)
+ ])
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+ def forward(self, enc_output, src_mask):
+        enc_output = self.pe(enc_output)  # position embedding
+ for enc_layer in self.layer_stack:
+ enc_output = enc_layer(enc_output, self_mask=src_mask)
+ enc_output = self.layer_norm(enc_output)
+ return enc_output
+
+
+class PP_layer(nn.Module):
+
+ def __init__(self, n_dim=512, N_max_character=25, n_position=256):
+ super(PP_layer, self).__init__()
+ self.character_len = N_max_character
+ self.f0_embedding = nn.Embedding(N_max_character, n_dim)
+ self.w0 = nn.Linear(N_max_character, n_position)
+ self.wv = nn.Linear(n_dim, n_dim)
+ self.we = nn.Linear(n_dim, N_max_character)
+ self.active = nn.Tanh()
+ self.softmax = nn.Softmax(dim=2)
+
+ def forward(self, enc_output):
+ reading_order = torch.arange(self.character_len,
+ dtype=torch.long,
+ device=enc_output.device)
+ reading_order = reading_order.unsqueeze(0).expand(
+ enc_output.shape[0], -1) # (S,) -> (B, S)
+ reading_order = self.f0_embedding(reading_order) # b,25,512
+ # calculate attention
+
+ t = self.w0(reading_order.transpose(1, 2)) # b,512,256
+ t = self.active(t.transpose(1, 2) + self.wv(enc_output)) # b,256,512
+ t = self.we(t) # b,256,25
+ t = self.softmax(t.transpose(1, 2)) # b,25,256
+ g_output = torch.bmm(t, enc_output) # b,25,512
+ return g_output
+
+
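+# Prediction head of VisionLAN: a PP_layer + linear classifier (w_vrm) for the
+# main visual stream, plus a shared PP_layer + classifier (w_share) that is
+# applied to the two MLM branches (remaining string f_res and occluded character
+# f_sub) during training.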
+class Prediction(nn.Module):
+
+ def __init__(
+ self,
+ n_dim=512,
+ n_class=37,
+ N_max_character=25,
+ n_position=256,
+ ):
+ super(Prediction, self).__init__()
+ self.pp = PP_layer(n_dim=n_dim,
+ N_max_character=N_max_character,
+ n_position=n_position)
+ self.pp_share = PP_layer(n_dim=n_dim,
+ N_max_character=N_max_character,
+ n_position=n_position)
+ self.w_vrm = nn.Linear(n_dim, n_class) # output layer
+ self.w_share = nn.Linear(n_dim, n_class) # output layer
+ self.nclass = n_class
+
+ def forward(self, cnn_feature, f_res, f_sub, is_Train=False, use_mlm=True):
+ if is_Train:
+ if not use_mlm:
+ g_output = self.pp(cnn_feature) # b,25,512
+ g_output = self.w_vrm(g_output)
+ f_res = 0
+ f_sub = 0
+ return g_output, f_res, f_sub
+ g_output = self.pp(cnn_feature) # b,25,512
+ f_res = self.pp_share(f_res)
+ f_sub = self.pp_share(f_sub)
+ g_output = self.w_vrm(g_output)
+ f_res = self.w_share(f_res)
+ f_sub = self.w_share(f_sub)
+ return g_output, f_res, f_sub
+ else:
+ g_output = self.pp(cnn_feature) # b,25,512
+ g_output = self.w_vrm(g_output)
+ return g_output
+
+
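+# MLM (masked language-aware module of VisionLAN): given the visual sequence and
+# a character position label, it predicts a soft attention map over the flattened
+# spatial positions (mask_c) that localizes that character, then splits the input
+# into f_sub (the occluded character) and f_res (the remaining string) for the
+# two WCL transformer branches.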
+class MLM(nn.Module):
+ """Architecture of MLM."""
+
+ def __init__(
+ self,
+ n_dim=512,
+ n_position=256,
+ n_head=8,
+ dim_feedforward=2048,
+ max_text_length=25,
+ ):
+ super(MLM, self).__init__()
+ self.MLM_SequenceModeling_mask = Transformer_Encoder(
+ n_layers=2,
+ n_head=n_head,
+ d_model=n_dim,
+ d_inner=dim_feedforward,
+ n_position=n_position,
+ )
+ self.MLM_SequenceModeling_WCL = Transformer_Encoder(
+ n_layers=1,
+ n_head=n_head,
+ d_model=n_dim,
+ d_inner=dim_feedforward,
+ n_position=n_position,
+ )
+ self.pos_embedding = nn.Embedding(max_text_length, n_dim)
+ self.w0_linear = nn.Linear(1, n_position)
+ self.wv = nn.Linear(n_dim, n_dim)
+ self.active = nn.Tanh()
+ self.we = nn.Linear(n_dim, 1)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, input, label_pos):
+ # transformer unit for generating mask_c
+ feature_v_seq = self.MLM_SequenceModeling_mask(input, src_mask=None)
+ # position embedding layer
+ pos_emb = self.pos_embedding(label_pos.long())
+ pos_emb = self.w0_linear(torch.unsqueeze(pos_emb,
+ dim=2)).transpose(1, 2)
+ # fusion position embedding with features V & generate mask_c
+ att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
+ att_map_sub = self.we(att_map_sub) # b,256,1
+ att_map_sub = self.sigmoid(att_map_sub.transpose(1, 2)) # b,1,256
+ # WCL
+ # generate inputs for WCL
+ f_res = input * (1 - att_map_sub.transpose(1, 2)
+ ) # second path with remaining string
+ f_sub = input * (att_map_sub.transpose(1, 2)
+ ) # first path with occluded character
+ # transformer units in WCL
+ f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
+ f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
+ return f_res, f_sub, att_map_sub
+
+
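+# MLM_VRM couples the MLM with the visual reasoning branch. training_step selects
+# the VisionLAN curriculum: 'LF_1' trains the visual branch alone, 'LF_2' adds
+# the MLM, and 'LA' additionally occludes the selected character in the first
+# half of the batch (ratio = 2) before visual reasoning. At test time only the
+# visual branch runs, and class index 0 is treated as the end-of-sequence symbol
+# when accumulating per-sample output lengths.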
+class MLM_VRM(nn.Module):
+
+ def __init__(
+ self,
+ n_layers=3,
+ n_position=256,
+ n_dim=512,
+ n_head=8,
+ dim_feedforward=2048,
+ max_text_length=25,
+ nclass=37,
+ ):
+ super(MLM_VRM, self).__init__()
+ self.MLM = MLM(
+ n_dim=n_dim,
+ n_position=n_position,
+ n_head=n_head,
+ dim_feedforward=dim_feedforward,
+ max_text_length=max_text_length,
+ )
+ self.SequenceModeling = Transformer_Encoder(
+ n_layers=n_layers,
+ n_head=n_head,
+ d_model=n_dim,
+ d_inner=dim_feedforward,
+ n_position=n_position,
+ )
+ self.Prediction = Prediction(
+ n_dim=n_dim,
+ n_position=n_position,
+ N_max_character=max_text_length + 1,
+ n_class=nclass,
+ ) # N_max_character = 1 eos + 25 characters
+ self.nclass = nclass
+ self.max_text_length = max_text_length
+
+ def forward(self, input, label_pos, training_step, is_Train=False):
+ nT = self.max_text_length
+
+ b, c, h, w = input.shape
+ input = input.reshape(b, c, -1)
+ input = input.transpose(1, 2)
+
+ if is_Train:
+ if training_step == 'LF_1':
+ f_res = 0
+ f_sub = 0
+ input = self.SequenceModeling(input, src_mask=None)
+ text_pre, text_rem, text_mas = self.Prediction(input,
+ f_res,
+ f_sub,
+ is_Train=True,
+ use_mlm=False)
+ return text_pre, text_pre, text_pre
+ elif training_step == 'LF_2':
+ # MLM
+ f_res, f_sub, mask_c = self.MLM(input, label_pos)
+ input = self.SequenceModeling(input, src_mask=None)
+ text_pre, text_rem, text_mas = self.Prediction(input,
+ f_res,
+ f_sub,
+ is_Train=True)
+ return text_pre, text_rem, text_mas
+ elif training_step == 'LA':
+ # MLM
+ f_res, f_sub, mask_c = self.MLM(input, label_pos)
+ # use the mask_c (1 for occluded character and 0 for remaining characters) to occlude input
+ # ratio controls the occluded number in a batch
+ ratio = 2
+ character_mask = torch.zeros_like(mask_c)
+ character_mask[0:b // ratio, :, :] = mask_c[0:b // ratio, :, :]
+ input = input * (1 - character_mask.transpose(1, 2))
+ # VRM
+ # transformer unit for VRM
+ input = self.SequenceModeling(input, src_mask=None)
+ # prediction layer for MLM and VSR
+ text_pre, text_rem, text_mas = self.Prediction(input,
+ f_res,
+ f_sub,
+ is_Train=True)
+ return text_pre, text_rem, text_mas
+ else: # VRM is only used in the testing stage
+ f_res = 0
+ f_sub = 0
+ contextual_feature = self.SequenceModeling(input, src_mask=None)
+ C = self.Prediction(contextual_feature,
+ f_res,
+ f_sub,
+ is_Train=False,
+ use_mlm=False)
+            C = C.transpose(1, 0)  # (nT, b, nclass)
+ out_res = torch.zeros(nT, b, self.nclass).type_as(input.data)
+
+ out_length = torch.zeros(b).type_as(input.data)
+ now_step = 0
+ while 0 in out_length and now_step < nT:
+ tmp_result = C[now_step, :, :]
+ out_res[now_step] = tmp_result
+ tmp_result = tmp_result.topk(1)[1].squeeze(dim=1)
+ for j in range(b):
+ if out_length[j] == 0 and tmp_result[j] == 0:
+ out_length[j] = now_step + 1
+ now_step += 1
+ for j in range(0, b):
+ if int(out_length[j]) == 0:
+ out_length[j] = nT
+ start = 0
+ output = torch.zeros(int(out_length.sum()),
+ self.nclass).type_as(input.data)
+ for i in range(0, b):
+ cur_length = int(out_length[i])
+ output[start:start + cur_length] = out_res[0:cur_length, i, :]
+ start += cur_length
+
+ return output, out_length
+
+
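+# VisionLANDecoder: thin wrapper that derives n_head (n_dim // 32 by default) and
+# the feed-forward width (4 * n_dim) from in_channels and delegates to MLM_VRM.
+# During training, data[-2] is expected to hold the character position labels
+# (label_pos).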
+class VisionLANDecoder(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ n_head=None,
+ training_step='LA',
+ n_layers=3,
+ n_position=256,
+ max_text_length=25,
+ ):
+ super(VisionLANDecoder, self).__init__()
+ self.training_step = training_step
+ n_dim = in_channels
+ dim_feedforward = n_dim * 4
+ n_head = n_head if n_head is not None else n_dim // 32
+
+ self.MLM_VRM = MLM_VRM(
+ n_layers=n_layers,
+ n_position=n_position,
+ n_dim=n_dim,
+ n_head=n_head,
+ dim_feedforward=dim_feedforward,
+ max_text_length=max_text_length,
+ nclass=out_channels + 1,
+ )
+
+ def forward(self, x, data=None):
+ # MLM + VRM
+ if self.training:
+ label_pos = data[-2]
+ text_pre, text_rem, text_mas = self.MLM_VRM(x,
+ label_pos,
+ self.training_step,
+ is_Train=True)
+ return text_pre, text_rem, text_mas
+ else:
+ output, out_length = self.MLM_VRM(x,
+ None,
+ self.training_step,
+ is_Train=False)
+ return output, out_length
diff --git a/openrec/modeling/encoders/__init__.py b/openrec/modeling/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3496c4835807c12398b05f390c35d504c57fd1c
--- /dev/null
+++ b/openrec/modeling/encoders/__init__.py
@@ -0,0 +1,39 @@
+__all__ = ['build_encoder']
+
+
+def build_encoder(config):
+ # from .rec_mobilenet_v3 import MobileNetV3
+ from .focalsvtr import FocalSVTR
+ from .rec_hgnet import PPHGNet_small
+ from .rec_lcnetv3 import PPLCNetV3
+ from .rec_mv1_enhance import MobileNetV1Enhance
+ from .rec_nrtr_mtb import MTB
+ from .rec_resnet_31 import ResNet31
+ from .rec_resnet_45 import ResNet45
+ from .rec_resnet_fpn import ResNet_FPN
+ from .rec_resnet_vd import ResNet
+ from .resnet31_rnn import ResNet_ASTER
+ from .svtrnet import SVTRNet
+ from .svtrnet2dpos import SVTRNet2DPos
+ from .svtrv2 import SVTRv2
+ from .svtrv2_lnconv import SVTRv2LNConv
+ from .svtrv2_lnconv_two33 import SVTRv2LNConvTwo33
+ from .vit import ViT
+ from .cam_encoder import CAMEncoder
+ from .convnextv2 import ConvNeXtV2
+ from .autostr_encoder import AutoSTREncoder
+ from .nrtr_encoder import NRTREncoder
+ from .repvit import RepSVTREncoder
+ support_dict = [
+ 'MobileNetV1Enhance', 'ResNet31', 'MobileNetV3', 'PPLCNetV3',
+ 'PPHGNet_small', 'ResNet', 'MTB', 'SVTRNet', 'ResNet45', 'ViT',
+ 'SVTRNet2DPos', 'SVTRv2', 'FocalSVTR', 'ResNet_FPN', 'ResNet_ASTER',
+ 'SVTRv2LNConv', 'SVTRv2LNConvTwo33', 'CAMEncoder', 'ConvNeXtV2',
+ 'AutoSTREncoder', 'NRTREncoder', 'RepSVTREncoder'
+ ]
+
+ module_name = config.pop('name')
+    assert module_name in support_dict, (
+        'encoder of rec model only supports {}'.format(support_dict))
+ module_class = eval(module_name)(**config)
+ return module_class
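+
+
+# Usage sketch (illustrative; the exact config keys depend on the chosen encoder):
+#   encoder = build_encoder({'name': 'SVTRv2LNConv', 'in_channels': 3})
+# 'name' is popped from the config and the remaining keys are passed to the
+# encoder constructor.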
diff --git a/openrec/modeling/encoders/__pycache__/__init__.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e13ac5ad32fce7b941b62ebeb8b451e45d8de11
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/__init__.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/autostr_encoder.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/autostr_encoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e91eeed05bcce4c6d6f91bcc0baea60cfd8b82da
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/autostr_encoder.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/cam_encoder.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/cam_encoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..158728934b5304dbcfe49c5fa14f4675baedba31
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/cam_encoder.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/convnextv2.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/convnextv2.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f34114c8f1ff39dda3cc048389c502c17ddb9c3
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/convnextv2.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/focalsvtr.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/focalsvtr.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51c0c37ec222c04c68f915cb68c89736d1d5a295
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/focalsvtr.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/nrtr_encoder.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/nrtr_encoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d08e7dcf8998a9e48a5aa9328294a47d9157814e
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/nrtr_encoder.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/rec_hgnet.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/rec_hgnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a01348c1f1822828179b4f83e506e40af241d9d
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/rec_hgnet.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/rec_lcnetv3.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/rec_lcnetv3.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f194b19501315a00f7764233876b18cd256b2ea
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/rec_lcnetv3.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/rec_mv1_enhance.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/rec_mv1_enhance.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a92e91cd39b8def28ffe934b96d0e364c8767eeb
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/rec_mv1_enhance.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/rec_nrtr_mtb.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/rec_nrtr_mtb.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02a6bb2ffcaf5ef4fab49789ff3db713794a2554
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/rec_nrtr_mtb.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/rec_resnet_31.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/rec_resnet_31.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70301cadac3f64cd34a29f2589d16760ed0866ee
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/rec_resnet_31.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/rec_resnet_45.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/rec_resnet_45.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26b9328b42a3c5c1ba7325cfa26cbc5d8b45490b
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/rec_resnet_45.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/rec_resnet_fpn.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/rec_resnet_fpn.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c26e8010de1c7f3cb594791593373bddf6d959f
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/rec_resnet_fpn.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/rec_resnet_vd.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/rec_resnet_vd.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e62e0bb03d9331894fd7205dd5a0b4ca6b8fb47
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/rec_resnet_vd.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/repvit.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/repvit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d57f62de5cb2944c47b75924449c791ab6d1de5a
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/repvit.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/resnet31_rnn.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/resnet31_rnn.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9fad4acd5db8b7f455db9e9d5578868c2f04ae5
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/resnet31_rnn.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/svtrnet.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/svtrnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d8ceb2b6a0a38f415563a8f2c317684edb30ca0
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/svtrnet.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/svtrnet2dpos.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/svtrnet2dpos.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a59c710d709c02f0517cb8a485e27c6a5cc2aea4
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/svtrnet2dpos.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/svtrv2.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/svtrv2.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..698a53552eb043d60fc16078b10f7d22b79af3fc
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/svtrv2.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/svtrv2_lnconv.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/svtrv2_lnconv.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e58b6b55ff8e0fa2e3fe23bdc2094608af78afd
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/svtrv2_lnconv.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/svtrv2_lnconv_two33.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/svtrv2_lnconv_two33.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..079aa55e34a90613d0897f8138153c7e66b55639
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/svtrv2_lnconv_two33.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/__pycache__/vit.cpython-38.pyc b/openrec/modeling/encoders/__pycache__/vit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5886f13ca04d088297c56d861408a88bce0c2c7
Binary files /dev/null and b/openrec/modeling/encoders/__pycache__/vit.cpython-38.pyc differ
diff --git a/openrec/modeling/encoders/autostr_encoder.py b/openrec/modeling/encoders/autostr_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9af0cf8a1ab780fa0e8c665bf6cdda0eb39dad0e
--- /dev/null
+++ b/openrec/modeling/encoders/autostr_encoder.py
@@ -0,0 +1,327 @@
+from collections import OrderedDict
+import torch
+import torch.nn as nn
+
+
+class IdentityLayer(nn.Module):
+
+ def __init__(self):
+ super(IdentityLayer, self).__init__()
+
+ def forward(self, x):
+ return x
+
+ @staticmethod
+ def is_zero_layer():
+ return False
+
+
+class ZeroLayer(nn.Module):
+
+ def __init__(self, stride):
+ super(ZeroLayer, self).__init__()
+ self.stride = stride
+
+ def forward(self, x):
+ n, c, h, w = x.shape
+ h //= self.stride[0]
+ w //= self.stride[1]
+ device = x.device
+ padding = torch.zeros(n, c, h, w, device=device, requires_grad=False)
+ return padding
+
+ @staticmethod
+ def is_zero_layer():
+ return True
+
+ def get_flops(self, x):
+ return 0, self.forward(x)
+
+
+def get_same_padding(kernel_size):
+ if isinstance(kernel_size, tuple):
+ assert len(kernel_size) == 2, 'invalid kernel size: %s' % kernel_size
+ p1 = get_same_padding(kernel_size[0])
+ p2 = get_same_padding(kernel_size[1])
+ return p1, p2
+ assert isinstance(kernel_size,
+ int), 'kernel size should be either `int` or `tuple`'
+ assert kernel_size % 2 > 0, 'kernel size should be odd number'
+ return kernel_size // 2
+
+
+class MBInvertedConvLayer(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=(1, 1),
+ expand_ratio=6,
+ mid_channels=None):
+ super(MBInvertedConvLayer, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.expand_ratio = expand_ratio
+ self.mid_channels = mid_channels
+
+ feature_dim = round(
+ self.in_channels *
+ self.expand_ratio) if mid_channels is None else mid_channels
+ if self.expand_ratio == 1:
+ self.inverted_bottleneck = None
+ else:
+ self.inverted_bottleneck = nn.Sequential(
+ OrderedDict([
+ ('conv',
+ nn.Conv2d(self.in_channels,
+ feature_dim,
+ 1,
+ 1,
+ 0,
+ bias=False)),
+ ('bn', nn.BatchNorm2d(feature_dim)),
+ ('act', nn.ReLU6(inplace=True)),
+ ]))
+ pad = get_same_padding(self.kernel_size)
+ self.depth_conv = nn.Sequential(
+ OrderedDict([
+ ('conv',
+ nn.Conv2d(feature_dim,
+ feature_dim,
+ kernel_size,
+ stride,
+ pad,
+ groups=feature_dim,
+ bias=False)),
+ ('bn', nn.BatchNorm2d(feature_dim)),
+ ('act', nn.ReLU6(inplace=True)),
+ ]))
+ self.point_conv = nn.Sequential(
+ OrderedDict([
+ ('conv',
+ nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
+ ('bn', nn.BatchNorm2d(out_channels)),
+ ]))
+
+ def forward(self, x):
+ if self.inverted_bottleneck:
+ x = self.inverted_bottleneck(x)
+ x = self.depth_conv(x)
+ x = self.point_conv(x)
+ return x
+
+ @staticmethod
+ def is_zero_layer():
+ return False
+
+
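+# Candidate-op naming convention: 'KxK_MBConvE' builds an inverted-residual
+# (MobileNetV2-style) block with a KxK depthwise kernel and expansion ratio E;
+# 'Zero' drops the path entirely and 'Identity' passes the input through.
+# Example: conv_func_by_name('5x5_MBConv3')(32, 64, (2, 1)) maps 32 -> 64
+# channels with a 5x5 depthwise kernel and stride (2, 1).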
+def conv_func_by_name(name):
+ name2ops = {
+ 'Identity': lambda in_C, out_C, S: IdentityLayer(),
+ 'Zero': lambda in_C, out_C, S: ZeroLayer(stride=S),
+ }
+ name2ops.update({
+ '3x3_MBConv1':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 3, S, 1),
+ '3x3_MBConv2':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 3, S, 2),
+ '3x3_MBConv3':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 3, S, 3),
+ '3x3_MBConv4':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 3, S, 4),
+ '3x3_MBConv5':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 3, S, 5),
+ '3x3_MBConv6':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 3, S, 6),
+ #######################################################################################
+ '5x5_MBConv1':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 5, S, 1),
+ '5x5_MBConv2':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 5, S, 2),
+ '5x5_MBConv3':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 5, S, 3),
+ '5x5_MBConv4':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 5, S, 4),
+ '5x5_MBConv5':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 5, S, 5),
+ '5x5_MBConv6':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 5, S, 6),
+ #######################################################################################
+ '7x7_MBConv1':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 7, S, 1),
+ '7x7_MBConv2':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 7, S, 2),
+ '7x7_MBConv3':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 7, S, 3),
+ '7x7_MBConv4':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 7, S, 4),
+ '7x7_MBConv5':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 7, S, 5),
+ '7x7_MBConv6':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 7, S, 6),
+ })
+ return name2ops[name]
+
+
+def build_candidate_ops(candidate_ops, in_channels, out_channels, stride,
+ ops_order):
+ if candidate_ops is None:
+ raise ValueError('please specify a candidate set')
+
+ name2ops = {
+ 'Identity':
+        lambda in_C, out_C, S: IdentityLayer(),  # IdentityLayer takes no args
+ 'Zero':
+ lambda in_C, out_C, S: ZeroLayer(stride=S),
+ }
+ # add MBConv layers
+ name2ops.update({
+ '3x3_MBConv1':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 3, S, 1),
+ '3x3_MBConv2':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 3, S, 2),
+ '3x3_MBConv3':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 3, S, 3),
+ '3x3_MBConv4':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 3, S, 4),
+ '3x3_MBConv5':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 3, S, 5),
+ '3x3_MBConv6':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 3, S, 6),
+ #######################################################################################
+ '5x5_MBConv1':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 5, S, 1),
+ '5x5_MBConv2':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 5, S, 2),
+ '5x5_MBConv3':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 5, S, 3),
+ '5x5_MBConv4':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 5, S, 4),
+ '5x5_MBConv5':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 5, S, 5),
+ '5x5_MBConv6':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 5, S, 6),
+ #######################################################################################
+ '7x7_MBConv1':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 7, S, 1),
+ '7x7_MBConv2':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 7, S, 2),
+ '7x7_MBConv3':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 7, S, 3),
+ '7x7_MBConv4':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 7, S, 4),
+ '7x7_MBConv5':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 7, S, 5),
+ '7x7_MBConv6':
+ lambda in_C, out_C, S: MBInvertedConvLayer(in_C, out_C, 7, S, 6),
+ })
+
+ return [
+ name2ops[name](in_channels, out_channels, stride)
+ for name in candidate_ops
+ ]
+
+
+class MobileInvertedResidualBlock(nn.Module):
+
+ def __init__(self, mobile_inverted_conv, shortcut):
+ super(MobileInvertedResidualBlock, self).__init__()
+ self.mobile_inverted_conv = mobile_inverted_conv
+ self.shortcut = shortcut
+
+ def forward(self, x):
+ if self.mobile_inverted_conv.is_zero_layer():
+ res = x
+ elif self.shortcut is None or self.shortcut.is_zero_layer():
+ res = self.mobile_inverted_conv(x)
+ else:
+ conv_x = self.mobile_inverted_conv(x)
+ skip_x = self.shortcut(x)
+ res = skip_x + conv_x
+ return res
+
+
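+# AutoSTREncoder: a searched backbone in which conv_op_ids indexes into
+# conv_candidates, one id per cell (sum(n_cell_stages) cells, 15 by default).
+# Only the first cell of each stage applies the stage stride; an identity
+# shortcut is added when the stride is (1, 1) and the channel count is unchanged.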
+class AutoSTREncoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_dim=256,
+ with_lstm=True,
+ stride_stages='[(2, 2), (2, 2), (2, 1), (2, 1), (2, 1)]',
+ n_cell_stages=[3, 3, 3, 3, 3],
+ conv_op_ids=[5, 5, 5, 5, 5, 5, 5, 6, 6, 5, 4, 3, 4, 6, 6],
+ **kwargs):
+ super().__init__()
+ self.first_conv = nn.Sequential(
+ nn.Conv2d(in_channels,
+ 32,
+ kernel_size=(3, 3),
+ stride=1,
+ padding=1,
+ bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True))
+ stride_stages = eval(stride_stages)
+ width_stages = [32, 64, 128, 256, 512]
+ conv_candidates = [
+ '5x5_MBConv1', '5x5_MBConv3', '5x5_MBConv6', '3x3_MBConv1',
+ '3x3_MBConv3', '3x3_MBConv6', 'Zero'
+ ]
+
+ assert len(conv_op_ids) == sum(n_cell_stages)
+ blocks = []
+ input_channel = 32
+ for width, n_cell, s in zip(width_stages, n_cell_stages,
+ stride_stages):
+ for i in range(n_cell):
+ if i == 0:
+ stride = s
+ else:
+ stride = (1, 1)
+ block_i = len(blocks)
+ conv_op = conv_func_by_name(
+ conv_candidates[conv_op_ids[block_i]])(input_channel,
+ width, stride)
+ if stride == (1, 1) and input_channel == width:
+ shortcut = IdentityLayer()
+ else:
+ shortcut = None
+ inverted_residual_block = MobileInvertedResidualBlock(
+ conv_op, shortcut)
+ blocks.append(inverted_residual_block)
+ input_channel = width
+ self.out_channels = input_channel
+
+ self.blocks = nn.ModuleList(blocks)
+
+ # with_lstm = False
+ self.with_lstm = with_lstm
+ if with_lstm:
+ self.rnn = nn.LSTM(input_channel,
+ out_dim // 2,
+ bidirectional=True,
+ num_layers=2,
+ batch_first=True)
+ self.out_channels = out_dim
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight,
+ mode='fan_out',
+ nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.first_conv(x)
+ for block in self.blocks:
+ x = block(x)
+ cnn_feat = x.squeeze(dim=2)
+ cnn_feat = cnn_feat.transpose(2, 1)
+ if self.with_lstm:
+ rnn_feat, _ = self.rnn(cnn_feat)
+ return rnn_feat
+ else:
+ return cnn_feat
diff --git a/openrec/modeling/encoders/cam_encoder.py b/openrec/modeling/encoders/cam_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bad258a52bae046d4d887c95e045ab9e9adc5b7
--- /dev/null
+++ b/openrec/modeling/encoders/cam_encoder.py
@@ -0,0 +1,760 @@
+"""This code is refer from:
+https://github.com/MelosY/CAM
+"""
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import trunc_normal_
+
+from .convnextv2 import ConvNeXtV2, Block, LayerNorm
+from .svtrv2_lnconv_two33 import SVTRv2LNConvTwo33
+
+
+class Swish(nn.Module):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class UNetBlock(nn.Module):
+
+ def __init__(self, cin, cout, bn2d, stride, deformable=False):
+ """
+        A UNet-style block that upsamples by the given stride, then refines with two convs.
+ """
+ super().__init__()
+ stride_h, stride_w = stride
+ if stride_h == 1:
+ kernel_h = 1
+ padding_h = 0
+ elif stride_h == 2:
+ kernel_h = 4
+ padding_h = 1
+ elif stride_h == 4:
+ kernel_h = 4
+ padding_h = 0
+
+ if stride_w == 1:
+ kernel_w = 1
+ padding_w = 0
+ elif stride_w == 2:
+ kernel_w = 4
+ padding_w = 1
+ elif stride_w == 4:
+ kernel_w = 4
+ padding_w = 0
+
+ conv = nn.Conv2d
+
+ self.up_sample = nn.ConvTranspose2d(cin,
+ cin,
+ kernel_size=(kernel_h, kernel_w),
+ stride=(stride_h, stride_w),
+ padding=(padding_h, padding_w),
+ bias=True)
+ self.conv = nn.Sequential(
+ conv(cin, cin, kernel_size=3, stride=1, padding=1, bias=False),
+ bn2d(cin),
+ nn.ReLU6(inplace=True),
+ conv(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
+ bn2d(cout),
+ )
+
+ def forward(self, x):
+ x = self.up_sample(x)
+ return self.conv(x)
+
+
+class DepthWiseUNetBlock(nn.Module):
+
+ def __init__(self, cin, cout, bn2d, stride, deformable=False):
+ """
+        A depthwise-separable variant of UNetBlock: upsamples by the given stride.
+ """
+ super().__init__()
+ stride_h, stride_w = stride
+ if stride_h == 1:
+ kernel_h = 1
+ padding_h = 0
+ elif stride_h == 2:
+ kernel_h = 4
+ padding_h = 1
+ elif stride_h == 4:
+ kernel_h = 4
+ padding_h = 0
+
+ if stride_w == 1:
+ kernel_w = 1
+ padding_w = 0
+ elif stride_w == 2:
+ kernel_w = 4
+ padding_w = 1
+ elif stride_w == 4:
+ kernel_w = 4
+ padding_w = 0
+
+ self.up_sample = nn.ConvTranspose2d(cin,
+ cin,
+ kernel_size=(kernel_h, kernel_w),
+ stride=(stride_h, stride_w),
+ padding=(padding_h, padding_w),
+ bias=True)
+ self.conv = nn.Sequential(
+ nn.Conv2d(cin,
+ cin,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=cin),
+ nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
+ bias=False),
+ bn2d(cin),
+ nn.ReLU6(inplace=True),
+ nn.Conv2d(cin,
+ cin,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=cin),
+ nn.Conv2d(cin,
+ cout,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ bn2d(cout),
+ )
+
+ def forward(self, x):
+ x = self.up_sample(x)
+ return self.conv(x)
+
+
+class SFTLayer(nn.Module):
+
+ def __init__(self, dim_in, dim_out):
+ super(SFTLayer, self).__init__()
+ self.SFT_scale_conv0 = nn.Linear(
+ dim_in,
+ dim_in,
+ )
+ self.SFT_scale_conv1 = nn.Linear(
+ dim_in,
+ dim_out,
+ )
+ self.SFT_shift_conv0 = nn.Linear(
+ dim_in,
+ dim_in,
+ )
+ self.SFT_shift_conv1 = nn.Linear(
+ dim_in,
+ dim_out,
+ )
+
+ def forward(self, x):
+ # x[0]: fea; x[1]: cond
+ scale = self.SFT_scale_conv1(
+ F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.1, inplace=True))
+ shift = self.SFT_shift_conv1(
+ F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.1, inplace=True))
+ return x[0] * (scale + 1) + shift
+
+
+class MoreUNetBlock(nn.Module):
+
+ def __init__(self, cin, cout, bn2d, stride, deformable=False):
+ """
+        A deeper depthwise-separable variant of UNetBlock: upsamples by the given stride.
+ """
+ super().__init__()
+ stride_h, stride_w = stride
+ if stride_h == 1:
+ kernel_h = 1
+ padding_h = 0
+ elif stride_h == 2:
+ kernel_h = 4
+ padding_h = 1
+ elif stride_h == 4:
+ kernel_h = 4
+ padding_h = 0
+
+ if stride_w == 1:
+ kernel_w = 1
+ padding_w = 0
+ elif stride_w == 2:
+ kernel_w = 4
+ padding_w = 1
+ elif stride_w == 4:
+ kernel_w = 4
+ padding_w = 0
+
+ self.up_sample = nn.ConvTranspose2d(cin,
+ cin,
+ kernel_size=(kernel_h, kernel_w),
+ stride=(stride_h, stride_w),
+ padding=(padding_h, padding_w),
+ bias=True)
+ self.conv = nn.Sequential(
+ nn.Conv2d(cin,
+ cin,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=cin),
+ nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
+ bias=False), bn2d(cin), nn.ReLU6(inplace=True),
+ nn.Conv2d(cin,
+ cin,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=cin),
+ nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
+ bias=False), bn2d(cin), nn.ReLU6(inplace=True),
+ nn.Conv2d(cin,
+ cin,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=cin),
+ nn.Conv2d(cin,
+ cout,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False), bn2d(cout), nn.ReLU6(inplace=True),
+ nn.Conv2d(cout,
+ cout,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=cout),
+ nn.Conv2d(cout,
+ cout,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False), bn2d(cout))
+
+ def forward(self, x):
+ x = self.up_sample(x)
+ return self.conv(x)
+
+
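+# BinaryDecoder: upsamples the encoder feature through a stack of UNet-style
+# blocks and predicts a per-pixel character-class ("binary" segmentation) map.
+# The number of segmentation classes is num_classes - 2 or num_classes - 3,
+# depending on binary_loss_type. Channel widths, assuming dim = 512:
+#   512 -> 256 -> 128 -> 64 -> 64 -> segm_num_cls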
+class BinaryDecoder(nn.Module):
+
+ def __init__(self,
+ dim,
+ num_classes,
+ strides,
+ use_depthwise_unet=False,
+ use_more_unet=False,
+ binary_loss_type='DiceLoss') -> None:
+ super().__init__()
+
+ channels = [dim // 2**i for i in range(4)]
+ self.linear_enc2binary = nn.Sequential(
+ nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
+ nn.SyncBatchNorm(dim),
+ )
+ self.strides = strides
+ self.use_deformable = False
+ self.binary_decoder = nn.ModuleList()
+ unet = DepthWiseUNetBlock if use_depthwise_unet else UNetBlock
+ unet = MoreUNetBlock if use_more_unet else unet
+
+ for i in range(3):
+ up_sample_stride = self.strides[::-1][i]
+ cin, cout = channels[i], channels[i + 1]
+ self.binary_decoder.append(
+ unet(cin, cout, nn.SyncBatchNorm, up_sample_stride,
+ self.use_deformable))
+
+ last_stride = (self.strides[0][0] // 2, self.strides[0][1] // 2)
+ self.binary_decoder.append(
+ unet(cout, cout, nn.SyncBatchNorm, last_stride,
+ self.use_deformable))
+
+ if binary_loss_type == 'CrossEntropyDiceLoss' or binary_loss_type == 'BanlanceMultiClassCrossEntropyLoss':
+ segm_num_cls = num_classes - 2
+ else:
+ segm_num_cls = num_classes - 3
+ self.binary_pred = nn.Conv2d(channels[-1],
+ segm_num_cls,
+ kernel_size=1,
+ stride=1,
+ bias=True)
+
+ def patchify(self, imgs):
+ """
+ imgs: (N, 3, H, W)
+ x: (N, L, patch_size**2 *3)
+ """
+ p_h, p_w = self.strides[0]
+ p_h = p_h // 2
+ p_w = p_w // 2
+ h = imgs.shape[2] // p_h
+ w = imgs.shape[3] // p_w
+
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p_h, w, p_w))
+ x = torch.einsum('nchpwq->nhwpqc', x)
+ x = x.reshape(shape=(imgs.shape[0], h * w, p_h * p_w * 3))
+ return x
+
+ def unpatchify(self, x):
+ """
+ x: (N, patch_size**2, h, w)
+ imgs: (N, 3, H, W)
+ """
+ p_h, p_w = self.strides[0]
+ p_h = p_h // 2
+ p_w = p_w // 2
+ _, _, h, w = x.shape
+ assert p_h * p_w == x.shape[1]
+
+ x = x.permute(0, 2, 3, 1) # [N, h, w, 4*4]
+ x = x.reshape(shape=(x.shape[0], h, w, p_h, p_w))
+ x = torch.einsum('nhwpq->nhpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], h * p_h, w * p_w))
+ return imgs
+
+ def forward(self, x, time=None):
+ """
+        x: the encoder feature map used to initialize the binary (segmentation)
+            prediction.
+        time: unused placeholder argument.
+ """
+
+ binary_feats = []
+ x = self.linear_enc2binary(x)
+ binary_feats.append(x.clone())
+
+ for i, d in enumerate(self.binary_decoder):
+
+ x = d(x)
+ binary_feats.append(x.clone())
+ #return None,binary_feats
+ x = self.binary_pred(x)
+
+ if self.training:
+ return x, binary_feats
+ else:
+ # return torch.sigmoid(x), binary_feat
+ return x.softmax(1), binary_feats
+
+
+class LayerNormProxy(nn.Module):
+
+ def __init__(self, dim):
+
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 3, 1)
+ x = self.norm(x)
+ return x.permute(0, 3, 1, 2)
+
+
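+# DAttentionFuse: deformable attention used to fuse two feature maps. The
+# concatenation of (x, y) predicts per-group sampling offsets; keys and values
+# are bilinearly sampled from x at the offset locations, while the queries come
+# from y, optionally with a relative position bias sampled the same way.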
+class DAttentionFuse(nn.Module):
+
+ def __init__(
+ self,
+ q_size=(4, 32),
+ kv_size=(4, 32),
+ n_heads=8,
+ n_head_channels=80,
+ n_groups=4,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ stride=2,
+ offset_range_factor=2,
+ use_pe=True,
+ stage_idx=0,
+ ):
+ '''
+ stage_idx from 2 to 3
+ '''
+
+ super().__init__()
+ self.n_head_channels = n_head_channels
+ self.scale = self.n_head_channels**-0.5
+ self.n_heads = n_heads
+ self.q_h, self.q_w = q_size
+ self.kv_h, self.kv_w = kv_size
+ self.nc = n_head_channels * n_heads
+ self.n_groups = n_groups
+ self.n_group_channels = self.nc // self.n_groups
+ self.n_group_heads = self.n_heads // self.n_groups
+ self.use_pe = use_pe
+ self.offset_range_factor = offset_range_factor
+ ksizes = [9, 7, 5, 3]
+ kk = ksizes[stage_idx]
+
+ self.conv_offset = nn.Sequential(
+ nn.Conv2d(2 * self.n_group_channels,
+ 2 * self.n_group_channels,
+ kk,
+ stride,
+ kk // 2,
+ groups=self.n_group_channels),
+ LayerNormProxy(2 * self.n_group_channels), nn.GELU(),
+ nn.Conv2d(2 * self.n_group_channels, 2, 1, 1, 0, bias=False))
+
+ self.proj_q = nn.Conv2d(self.nc,
+ self.nc,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.proj_k = nn.Conv2d(self.nc,
+ self.nc,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.proj_v = nn.Conv2d(self.nc,
+ self.nc,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.proj_out = nn.Conv2d(self.nc,
+ self.nc,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.proj_drop = nn.Dropout(proj_drop, inplace=True)
+ self.attn_drop = nn.Dropout(attn_drop, inplace=True)
+
+ if self.use_pe:
+ self.rpe_table = nn.Parameter(
+ torch.zeros(self.n_heads, self.kv_h * 2 - 1,
+ self.kv_w * 2 - 1))
+ trunc_normal_(self.rpe_table, std=0.01)
+ else:
+ self.rpe_table = None
+
+ @torch.no_grad()
+ def _get_ref_points(self, H_key, W_key, B, dtype, device):
+
+ ref_y, ref_x = torch.meshgrid(
+ torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype,
+ device=device),
+ torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype,
+ device=device))
+ ref = torch.stack((ref_y, ref_x), -1)
+ ref[..., 1].div_(W_key).mul_(2).sub_(1)
+ ref[..., 0].div_(H_key).mul_(2).sub_(1)
+ ref = ref[None, ...].expand(B * self.n_groups, -1, -1,
+ -1) # B * g H W 2
+ return ref
+
+ def forward(self, x, y):
+ B, C, H, W = x.size()
+ dtype, device = x.dtype, x.device
+
+ q_off = torch.cat(
+ (x, y), dim=1
+ ).reshape(B, self.n_groups, 2 * self.n_group_channels, H, W).flatten(
+ 0, 1
+ ) #einops.rearrange(torch.cat((x,y),dim=1), 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=2*self.n_group_channels)
+
+ offset = self.conv_offset(q_off) # B * g 2 Hg Wg
+ Hk, Wk = offset.size(2), offset.size(3)
+ n_sample = Hk * Wk
+ if self.offset_range_factor > 0:
+ offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk],
+ device=device).reshape(1, 2, 1, 1)
+ offset = offset.tanh().mul(offset_range).mul(
+ self.offset_range_factor)
+
+ offset = offset.permute(
+ 0, 2, 3, 1) #einops.rearrange(offset, 'b p h w -> b h w p')
+ reference = self._get_ref_points(Hk, Wk, B, dtype, device)
+
+ if self.offset_range_factor >= 0:
+ pos = offset + reference
+ else:
+ pos = (offset + reference).tanh()
+
+ q = self.proj_q(y)
+ x_sampled = F.grid_sample(
+ input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
+ grid=pos[..., (1, 0)], # y, x -> x, y
+ mode='bilinear',
+ align_corners=False) # B * g, Cg, Hg, Wg
+
+ x_sampled = x_sampled.reshape(B, C, 1, n_sample)
+
+ q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
+ k = self.proj_k(x_sampled).reshape(B * self.n_heads,
+ self.n_head_channels, n_sample)
+ v = self.proj_v(x_sampled).reshape(B * self.n_heads,
+ self.n_head_channels, n_sample)
+
+ attn = torch.einsum('b c m, b c n -> b m n', q, k) # B * h, HW, Ns
+ attn = attn.mul(self.scale)
+
+ if self.use_pe:
+ rpe_table = self.rpe_table
+ rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
+
+ q_grid = self._get_ref_points(H, W, B, dtype, device)
+
+ displacement = (
+ q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) -
+ pos.reshape(B * self.n_groups, n_sample,
+ 2).unsqueeze(1)).mul(0.5)
+
+ attn_bias = F.grid_sample(input=rpe_bias.reshape(
+ B * self.n_groups, self.n_group_heads, 2 * H - 1, 2 * W - 1),
+ grid=displacement[..., (1, 0)],
+ mode='bilinear',
+ align_corners=False)
+
+ attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
+
+ attn = attn + attn_bias
+
+ attn = F.softmax(attn, dim=2)
+ attn = self.attn_drop(attn)
+
+ out = torch.einsum('b m n, b c n -> b c m', attn, v)
+ out = out.reshape(B, C, H, W)
+ out = self.proj_drop(self.proj_out(out))
+
+ return out, pos.reshape(B, self.n_groups, Hk, Wk,
+ 2), reference.reshape(B, self.n_groups, Hk, Wk,
+ 2)
+
+
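+# FuseModel: progressively brings the highest-resolution binary-decoder feature
+# back to the recognition-feature resolution (binary2refine_linear_norm), adds
+# the first decoder feature as a residual, then lets it query the recognition
+# feature through deformable attention (DAttentionFuse); the attended output is
+# added back to the aligned binary feature.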
+class FuseModel(nn.Module):
+
+ def __init__(self,
+ dim,
+ deform_stride=2,
+ stage_idx=2,
+ k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
+ q_size=(2, 32)):
+ super().__init__()
+
+ channels = [dim // 2**i for i in range(4)]
+
+ refine_conv = nn.Conv2d
+ self.deform_stride = deform_stride
+
+ in_out_ch = [(-1, -2), (-2, -3), (-3, -4), (-4, -4)]
+
+ self.binary_condition_layer = DAttentionFuse(q_size=q_size,
+ kv_size=q_size,
+ stride=self.deform_stride,
+ n_head_channels=dim // 8,
+ stage_idx=stage_idx)
+
+ self.binary2refine_linear_norm = nn.ModuleList()
+ for i in range(len(k_size)):
+ self.binary2refine_linear_norm.append(
+ nn.Sequential(
+ Block(dim=channels[in_out_ch[i][0]]),
+ LayerNorm(channels[in_out_ch[i][0]],
+ eps=1e-6,
+ data_format='channels_first'),
+ refine_conv(channels[in_out_ch[i][0]],
+ channels[in_out_ch[i][1]],
+ kernel_size=k_size[i],
+ stride=k_size[i])), # [8, 32]
+ )
+
+ def forward(self, recog_feat, binary_feats, dec_in=None):
+ multi_feat = []
+ binary_feat = binary_feats[-1]
+ for i in range(len(self.binary2refine_linear_norm)):
+ binary_feat = self.binary2refine_linear_norm[i](binary_feat)
+ multi_feat.append(binary_feat)
+ binary_feat = binary_feat + binary_feats[0]
+ multi_feat[3] += binary_feats[0]
+ binary_refined_feat, pos, _ = self.binary_condition_layer(
+ recog_feat, binary_feat)
+ binary_refined_feat = binary_refined_feat + binary_feat
+ return binary_refined_feat, binary_feat
+
+
+class CAMEncoder(nn.Module):
+ """
+
+ Args:
+ in_chans (int): Number of input image channels. Default: 3
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+
+ """
+
+ def __init__(self,
+ in_channels=3,
+ encoder_config={'name': 'ConvNeXtV2'},
+ nb_classes=71,
+ strides=[(4, 4), (2, 1), (2, 1), (1, 1)],
+ k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
+ q_size=[2, 32],
+ deform_stride=2,
+ stage_idx=2,
+ use_depthwise_unet=True,
+ use_more_unet=False,
+ binary_loss_type='BanlanceMultiClassCrossEntropyLoss',
+ mid_size=True,
+ d_embedding=384):
+ super().__init__()
+ encoder_name = encoder_config.pop('name')
+ encoder_config['in_channels'] = in_channels
+ self.backbone = eval(encoder_name)(**encoder_config)
+ dim = self.backbone.out_channels
+ self.mid_size = mid_size
+ if self.mid_size:
+ self.enc_downsample = nn.Sequential(
+ nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
+ nn.SyncBatchNorm(dim // 2),
+ #nn.ReLU6(inplace=True),
+ nn.Conv2d(dim // 2,
+ dim // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=dim // 2),
+ nn.Conv2d(dim // 2,
+ dim // 2,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ nn.SyncBatchNorm(dim // 2),
+ )
+ dim = dim // 2
+ # recognition decoder
+ self.linear_enc2recog = nn.Sequential(
+ nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=1,
+ stride=1,
+ ),
+ nn.SyncBatchNorm(dim),
+ #nn.ReLU6(inplace=True),
+ nn.Conv2d(dim,
+ dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=dim),
+ nn.Conv2d(dim,
+ dim,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ nn.SyncBatchNorm(dim),
+ )
+ else:
+ self.linear_enc2recog = nn.Sequential(
+ nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
+ nn.SyncBatchNorm(dim // 2),
+ #nn.ReLU6(inplace=True),
+ nn.Conv2d(dim // 2, dim, kernel_size=3, stride=1, padding=1),
+ nn.SyncBatchNorm(dim),
+ )
+
+ self.linear_norm = nn.Sequential(
+ nn.Linear(dim, d_embedding),
+ nn.LayerNorm(d_embedding, eps=1e-6),
+ )
+ self.out_channels = d_embedding
+
+ self.binary_decoder = BinaryDecoder(
+ dim,
+ nb_classes,
+ strides,
+ use_depthwise_unet=use_depthwise_unet,
+ use_more_unet=use_more_unet,
+ binary_loss_type=binary_loss_type)
+ self.fuse_model = FuseModel(dim,
+ deform_stride=deform_stride,
+ stage_idx=stage_idx,
+ k_size=k_size,
+ q_size=q_size)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, (nn.Conv2d, nn.Linear)) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ if isinstance(m, nn.ConvTranspose2d):
+ nn.init.kaiming_normal_(m.weight,
+ mode='fan_out',
+ nonlinearity='relu')
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0.)
+ elif isinstance(m, nn.LayerNorm):
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.SyncBatchNorm):
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.BatchNorm2d):
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1.0)
+
+ def no_weight_decay(self):
+ return {}
+
+ def forward(self, x):
+ output = {}
+ enc_feat = self.backbone(x)
+ if self.mid_size:
+ enc_feat = self.enc_downsample(enc_feat)
+ output['enc_feat'] = enc_feat
+
+ # binary mask
+ pred_binary, binary_feats = self.binary_decoder(enc_feat)
+ output['pred_binary'] = pred_binary
+
+ reg_feat = self.linear_enc2recog(enc_feat)
+ B, C, H, W = reg_feat.shape
+ last_feat, binary_feat = self.fuse_model(reg_feat, binary_feats)
+
+ dec_in = last_feat.reshape(B, C, H * W).permute(0, 2, 1)
+ dec_in = self.linear_norm(dec_in)
+
+ output['refined_feat'] = dec_in
+ output['binary_feat'] = binary_feats[-1]
+ return output
diff --git a/openrec/modeling/encoders/convnextv2.py b/openrec/modeling/encoders/convnextv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef73e2cf4351842b3f39cab2d91695841a43226e
--- /dev/null
+++ b/openrec/modeling/encoders/convnextv2.py
@@ -0,0 +1,213 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import trunc_normal_
+
+from openrec.modeling.common import DropPath
+
+
+class LayerNorm(nn.Module):
+ """ LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+ with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self,
+ normalized_shape,
+ eps=1e-6,
+ data_format='channels_last'):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ['channels_last', 'channels_first']:
+ raise NotImplementedError
+ self.normalized_shape = (normalized_shape, )
+
+ def forward(self, x):
+ if self.data_format == 'channels_last':
+ return F.layer_norm(x, self.normalized_shape, self.weight,
+ self.bias, self.eps)
+ elif self.data_format == 'channels_first':
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
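+# GRN (Global Response Normalization, ConvNeXtV2): for channels-last input x of
+# shape (N, H, W, C),
+#   Gx = ||x||_2 over (H, W)          # per-channel global norm, (N, 1, 1, C)
+#   Nx = Gx / (mean_C(Gx) + 1e-6)     # divisive normalization across channels
+#   y  = gamma * (x * Nx) + beta + x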
+class GRN(nn.Module):
+ """ GRN (Global Response Normalization) layer
+ """
+
+ def __init__(self, dim):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+ def forward(self, inputs, mask=None):
+ x = inputs
+ if mask is not None:
+ x = x * (1. - mask)
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma * (inputs * Nx) + self.beta + inputs
+
+
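+# ConvNeXtV2 block: 7x7 depthwise conv -> LayerNorm -> 1x1 expand to 4*dim
+# (as nn.Linear) -> GELU -> GRN -> 1x1 project back to dim -> stochastic-depth
+# residual connection.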
+class Block(nn.Module):
+ """ ConvNeXtV2 Block.
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ """
+
+ def __init__(self, dim, drop_path=0.):
+ super().__init__()
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3,
+ groups=dim) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(
+ dim,
+ 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.grn = GRN(4 * dim)
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x):
+ input = x
+ x = self.dwconv(x.contiguous())
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.grn(x)
+ x = self.pwconv2(x)
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+
+class ConvNeXtV2(nn.Module):
+ """ ConvNeXt V2
+
+ Args:
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+
+ """
+
+ def __init__(
+ self,
+ in_channels=3,
+ depths=[3, 3, 9, 3],
+ dims=[96, 192, 384, 768],
+ drop_path_rate=0.,
+ strides=[(4, 4), (2, 2), (2, 2), (2, 2)],
+ out_channels=256,
+ last_stage=False,
+ feat2d=False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.strides = strides
+ self.depths = depths
+ self.downsample_layers = nn.ModuleList(
+ ) # stem and 3 intermediate downsampling conv layers
+ stem = nn.Sequential(
+ nn.Conv2d(in_channels,
+ dims[0],
+ kernel_size=strides[0],
+ stride=strides[0]),
+ LayerNorm(dims[0], eps=1e-6, data_format='channels_first'))
+ self.downsample_layers.append(stem)
+ for i in range(3):
+ downsample_layer = nn.Sequential(
+ LayerNorm(dims[i], eps=1e-6, data_format='channels_first'),
+ nn.Conv2d(dims[i],
+ dims[i + 1],
+ kernel_size=strides[i + 1],
+ stride=strides[i + 1]),
+ )
+ self.downsample_layers.append(downsample_layer)
+
+ self.stages = nn.ModuleList(
+ ) # 4 feature resolution stages, each consisting of multiple residual blocks
+ dp_rates = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ]
+ cur = 0
+ for i in range(4):
+ stage = nn.Sequential(*[
+ Block(dim=dims[i], drop_path=dp_rates[cur + j])
+ for j in range(depths[i])
+ ])
+ self.stages.append(stage)
+ cur += depths[i]
+ self.out_channels = dims[-1]
+ self.last_stage = last_stage
+ self.feat2d = feat2d
+ if last_stage:
+ self.out_channels = out_channels
+ self.last_conv = nn.Linear(dims[-1], self.out_channels, bias=False)
+ self.hardswish = nn.Hardswish()
+ self.dropout = nn.Dropout(p=0.1)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, (nn.Conv2d, nn.Linear)) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.SyncBatchNorm):
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.BatchNorm2d):
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1.0)
+
+ def no_weight_decay(self):
+ return {}
+
+ def forward(self, x):
+ feats = []
+ for i in range(4):
+ x = self.downsample_layers[i](x)
+ x = self.stages[i](x)
+ feats.append(x)
+
+ if self.last_stage:
+ x = x.mean(2).transpose(1, 2)
+ x = self.last_conv(x)
+ x = self.hardswish(x)
+ x = self.dropout(x)
+ return x
+ if self.feat2d:
+ return x
+ return feats
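+
+
+# Editor's usage sketch (not part of the original patch): a minimal smoke test
+# of ConvNeXtV2 on a text-recognition-sized input; the input size and the
+# shapes in the comment are illustrative assumptions, not values from the source.
+if __name__ == '__main__':
+    model = ConvNeXtV2(in_channels=3,
+                       depths=[3, 3, 9, 3],
+                       dims=[96, 192, 384, 768])
+    feats = model(torch.randn(1, 3, 32, 128))
+    # with the default strides this yields one map per stage, roughly
+    # (1, 96, 8, 32), (1, 192, 4, 16), (1, 384, 2, 8) and (1, 768, 1, 4)
+    for feat in feats:
+        print(feat.shape)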
diff --git a/openrec/modeling/encoders/focalsvtr.py b/openrec/modeling/encoders/focalsvtr.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1bb8f92f8bd2a22996c59701cc3aea077b573f4
--- /dev/null
+++ b/openrec/modeling/encoders/focalsvtr.py
@@ -0,0 +1,631 @@
+# --------------------------------------------------------
+# FocalNets -- Focal Modulation Networks
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Jianwei Yang (jianwyan@microsoft.com)
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from torch.nn.init import trunc_normal_
+
+from openrec.modeling.common import DropPath, Mlp
+from openrec.modeling.encoders.svtrnet import ConvBNLayer
+
+
+class FocalModulation(nn.Module):
+
+ def __init__(self,
+ dim,
+ focal_window,
+ focal_level,
+ max_kh=None,
+ focal_factor=2,
+ bias=True,
+ proj_drop=0.0,
+ use_postln_in_modulation=False,
+ normalize_modulator=False):
+ super().__init__()
+
+ self.dim = dim
+ self.focal_window = focal_window
+ self.focal_level = focal_level
+ self.focal_factor = focal_factor
+ self.use_postln_in_modulation = use_postln_in_modulation
+ self.normalize_modulator = normalize_modulator
+
+ self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=bias)
+ self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)
+
+ self.act = nn.GELU()
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.focal_layers = nn.ModuleList()
+
+ self.kernel_sizes = []
+ for k in range(self.focal_level):
+ kernel_size = self.focal_factor * k + self.focal_window
+ if max_kh is not None:
+ k_h, k_w = [min(kernel_size, max_kh), kernel_size]
+ kernel_size = [k_h, k_w]
+ padding = [k_h // 2, k_w // 2]
+ else:
+ padding = kernel_size // 2
+ self.focal_layers.append(
+ nn.Sequential(
+ nn.Conv2d(dim,
+ dim,
+ kernel_size=kernel_size,
+ stride=1,
+ groups=dim,
+ padding=padding,
+ bias=False),
+ nn.GELU(),
+ ))
+ self.kernel_sizes.append(kernel_size)
+ if self.use_postln_in_modulation:
+ self.ln = nn.LayerNorm(dim)
+
+ def forward(self, x):
+ """
+ Args:
+ x: input features with shape of (B, H, W, C)
+ """
+ C = x.shape[-1]
+
+ # pre linear projection
+ x = self.f(x).permute(0, 3, 1, 2).contiguous()
+ q, ctx, self.gates = torch.split(x, (C, C, self.focal_level + 1), 1)
+
+ # context aggregation
+ ctx_all = 0
+ for l in range(self.focal_level):
+ ctx = self.focal_layers[l](ctx)
+ ctx_all = ctx_all + ctx * self.gates[:, l:l + 1]
+ ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
+ ctx_all = ctx_all + ctx_global * self.gates[:, self.focal_level:]
+
+ # normalize context
+ if self.normalize_modulator:
+ ctx_all = ctx_all / (self.focal_level + 1)
+
+ # focal modulation
+ self.modulator = self.h(ctx_all)
+ x_out = q * self.modulator
+ x_out = x_out.permute(0, 2, 3, 1).contiguous()
+ if self.use_postln_in_modulation:
+ x_out = self.ln(x_out)
+
+ # post linear projection
+ x_out = self.proj(x_out)
+ x_out = self.proj_drop(x_out)
+ return x_out
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+
+ flops += N * self.dim * (self.dim * 2 + (self.focal_level + 1))
+
+ # focal convolution
+ for k in range(self.focal_level):
+ flops += N * (self.kernel_sizes[k]**2 + 1) * self.dim
+
+ # global gating
+ flops += N * 1 * self.dim
+
+ # self.linear
+ flops += N * self.dim * (self.dim + 1)
+
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+
+class FocalNetBlock(nn.Module):
+ r"""Focal Modulation Network Block.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ drop (float, optional): Dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ focal_level (int): Number of focal levels.
+ focal_window (int): Focal window size at first focal level
+ use_layerscale (bool): Whether to use layerscale
+ layerscale_value (float): Initial layerscale value
+ use_postln (bool): Whether to apply layernorm after modulation
+ """
+
+ def __init__(
+ self,
+ dim,
+ input_resolution=None,
+ mlp_ratio=4.0,
+ drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ focal_level=1,
+ focal_window=3,
+ max_kh=None,
+ use_layerscale=False,
+ layerscale_value=1e-4,
+ use_postln=False,
+ use_postln_in_modulation=False,
+ normalize_modulator=False,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.mlp_ratio = mlp_ratio
+
+ self.focal_window = focal_window
+ self.focal_level = focal_level
+ self.use_postln = use_postln
+
+ self.norm1 = norm_layer(dim)
+ self.modulation = FocalModulation(
+ dim,
+ proj_drop=drop,
+ focal_window=focal_window,
+ focal_level=self.focal_level,
+ max_kh=max_kh,
+ use_postln_in_modulation=use_postln_in_modulation,
+ normalize_modulator=normalize_modulator,
+ )
+
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+
+ self.gamma_1 = 1.0
+ self.gamma_2 = 1.0
+ if use_layerscale:
+ self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)),
+ requires_grad=True)
+ self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)),
+ requires_grad=True)
+
+ self.H = None
+ self.W = None
+
+ def forward(self, x):
+ H, W = self.H, self.W
+ B, L, C = x.shape
+ shortcut = x
+
+ # Focal Modulation
+ x = x if self.use_postln else self.norm1(x)
+ x = x.view(B, H, W, C)
+ x = self.modulation(x).view(B, H * W, C)
+ x = x if not self.use_postln else self.norm1(x)
+
+ # FFN
+ x = shortcut + self.drop_path(self.gamma_1 * x)
+ x = x + self.drop_path(self.gamma_2 * (self.norm2(
+ self.mlp(x)) if self.use_postln else self.mlp(self.norm2(x))))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, input_resolution={self.input_resolution}, ' f'mlp_ratio={self.mlp_ratio}'
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+
+ # W-MSA/SW-MSA
+ flops += self.modulation.flops(H * W)
+
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """A basic Focal Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ out_dim (int | None): Number of output channels after downsampling; None for the last stage.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ drop (float, optional): Dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ downsample_kernel (list): Patch size used by the downsample layer.
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ focal_level (int): Number of focal levels
+ focal_window (int): Focal window size at first focal level
+ use_conv_embed (bool): Whether to use convolutional patch embedding for downsampling
+ use_layerscale (bool): Whether to use layerscale
+ layerscale_value (float): Initial layerscale value
+ use_postln (bool): Whether to apply layernorm after modulation
+ """
+
+ def __init__(
+ self,
+ dim,
+ out_dim,
+ input_resolution,
+ depth,
+ mlp_ratio=4.0,
+ drop=0.0,
+ drop_path=0.0,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ downsample_kernel=[],
+ use_checkpoint=False,
+ focal_level=1,
+ focal_window=1,
+ use_conv_embed=False,
+ use_layerscale=False,
+ layerscale_value=1e-4,
+ use_postln=False,
+ use_postln_in_modulation=False,
+ normalize_modulator=False,
+ ):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ FocalNetBlock(
+ dim=dim,
+ input_resolution=input_resolution,
+ mlp_ratio=mlp_ratio,
+ drop=drop,
+ drop_path=drop_path[i]
+ if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ focal_level=focal_level,
+ focal_window=focal_window,
+ use_layerscale=use_layerscale,
+ layerscale_value=layerscale_value,
+ use_postln=use_postln,
+ use_postln_in_modulation=use_postln_in_modulation,
+ normalize_modulator=normalize_modulator,
+ ) for i in range(depth)
+ ])
+
+ if downsample is not None:
+ self.downsample = downsample(
+ img_size=input_resolution,
+ patch_size=downsample_kernel,
+ in_chans=dim,
+ embed_dim=out_dim,
+ use_conv_embed=use_conv_embed,
+ norm_layer=norm_layer,
+ is_stem=False,
+ )
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+
+ if self.downsample is not None:
+ x = x.transpose(1, 2).reshape(x.shape[0], -1, H, W)
+ x, Ho, Wo = self.downsample(x)
+ else:
+ Ho, Wo = H, W
+ return x, Ho, Wo
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+
+class PatchEmbed(nn.Module):
+ r"""Image to Patch Embedding
+
+ Args:
+ img_size (tuple[int]): Image size. Default: (224, 224).
+ patch_size (list[int]): Patch token size. Default: [4, 4].
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ use_conv_embed (bool): Whether to use an overlapping convolutional projection. Default: False
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ is_stem (bool): Whether this embedding acts as the network stem. Default: False
+ """
+
+ def __init__(self,
+ img_size=(224, 224),
+ patch_size=[4, 4],
+ in_chans=3,
+ embed_dim=96,
+ use_conv_embed=False,
+ norm_layer=None,
+ is_stem=False):
+ super().__init__()
+ # patch_size = to_2tuple(patch_size)
+ patches_resolution = [
+ img_size[0] // patch_size[0], img_size[1] // patch_size[1]
+ ]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ if use_conv_embed:
+ # if we choose to use conv embedding, then we treat the stem and non-stem differently
+ if is_stem:
+ kernel_size = 7
+ padding = 2
+ stride = 4
+ else:
+ kernel_size = 3
+ padding = 1
+ stride = 2
+ self.proj = nn.Conv2d(in_chans,
+ embed_dim,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+ else:
+ self.proj = nn.Conv2d(in_chans,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=patch_size)
+
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+
+ x = self.proj(x)
+ H, W = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x, H, W
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (
+ self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+
+class FocalSVTR(nn.Module):
+ r"""Focal Modulation Networks (FocalNets)
+
+ Args:
+ img_size (list[int]): Input image size. Default: [32, 128]
+ patch_size (list[int]): Patch size. Default: [4, 4]
+ out_channels (int): Output dimension of the optional last stage. Default: 256
+ in_channels (int): Number of input image channels. Default: 3
+ embed_dim (int): Patch embedding dimension of the first stage. Default: 96
+ depths (tuple(int)): Depth of each Focal Transformer stage. Default: [3, 6, 3]
+ sub_k (list[list[int]]): (height, width) downsampling stride between stages. Default: [[2, 1], [2, 1], [1, 1]]
+ last_stage (bool): If True, pool the height dimension and project to out_channels. Default: False
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ drop_rate (float): Dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ focal_levels (list): Number of focal levels at each stage. Note that this excludes the finest-grain level. Default: [6, 6, 6]
+ focal_windows (list): The focal window size at each stage. Default: [3, 3, 3]
+ use_conv_embed (bool): Whether to use convolutional embedding. Convolutional embedding usually improves performance,
+ but it is not used by default. Default: False
+ use_layerscale (bool): Whether to use layerscale proposed in CaiT. Default: False
+ layerscale_value (float): Value for layer scale. Default: 1e-4
+ use_postln (bool): Whether to apply layernorm after modulation (it helps stabilize training of large models)
+ feat2d (bool): If True, return a 2D feature map instead of a token sequence. Default: False
+ """
+
+ def __init__(
+ self,
+ img_size=[32, 128],
+ patch_size=[4, 4],
+ out_channels=256,
+ out_char_num=25,
+ in_channels=3,
+ embed_dim=96,
+ depths=[3, 6, 3],
+ sub_k=[[2, 1], [2, 1], [1, 1]],
+ last_stage=False,
+ mlp_ratio=4.0,
+ drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ patch_norm=True,
+ use_checkpoint=False,
+ focal_levels=[6, 6, 6],
+ focal_windows=[3, 3, 3],
+ use_conv_embed=False,
+ use_layerscale=False,
+ layerscale_value=1e-4,
+ use_postln=False,
+ use_postln_in_modulation=False,
+ normalize_modulator=False,
+ feat2d=False,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.num_layers = len(depths)
+ embed_dim = [embed_dim * (2**i) for i in range(self.num_layers)]
+ self.feat2d = feat2d
+ self.embed_dim = embed_dim
+ self.patch_norm = patch_norm
+ self.num_features = embed_dim[-1]
+ self.mlp_ratio = mlp_ratio
+
+ self.patch_embed = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim[0] // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim[0] // 2,
+ out_channels=embed_dim[0],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ )
+
+ patches_resolution = [
+ img_size[0] // patch_size[0], img_size[1] // patch_size[1]
+ ]
+ self.patches_resolution = patches_resolution
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+
+ layer = BasicLayer(
+ dim=embed_dim[i_layer],
+ out_dim=embed_dim[i_layer + 1] if
+ (i_layer < self.num_layers - 1) else None,
+ input_resolution=patches_resolution,
+ depth=depths[i_layer],
+ mlp_ratio=self.mlp_ratio,
+ drop=drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchEmbed if
+ (i_layer < self.num_layers - 1) else None,
+ downsample_kernel=sub_k[i_layer],
+ focal_level=focal_levels[i_layer],
+ focal_window=focal_windows[i_layer],
+ use_conv_embed=use_conv_embed,
+ use_checkpoint=use_checkpoint,
+ use_layerscale=use_layerscale,
+ layerscale_value=layerscale_value,
+ use_postln=use_postln,
+ use_postln_in_modulation=use_postln_in_modulation,
+ normalize_modulator=normalize_modulator,
+ )
+ patches_resolution = [
+ patches_resolution[0] // sub_k[i_layer][0],
+ patches_resolution[1] // sub_k[i_layer][1]
+ ]
+ self.layers.append(layer)
+ self.out_channels = self.num_features
+ self.last_stage = last_stage
+ if last_stage:
+ self.out_channels = out_channels
+ self.last_conv = nn.Linear(self.num_features,
+ self.out_channels,
+ bias=False)
+ self.hardswish = nn.Hardswish()
+ self.dropout = nn.Dropout(p=0.1)
+ # self.avg_pool = nn.AdaptiveAvgPool2d([1, out_char_num])
+ # self.last_conv = nn.Conv2d(
+ # in_channels=self.num_features,
+ # out_channels=self.out_channels,
+ # kernel_size=1,
+ # stride=1,
+ # padding=0,
+ # bias=False,
+ # )
+ # self.hardswish = nn.Hardswish()
+ # self.dropout = nn.Dropout(p=0.1)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight,
+ mode='fan_out',
+ nonlinearity='relu')
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'patch_embed', 'downsample'}
+
+ def forward(self, x):
+ if len(x.shape) == 5:
+ x = x.flatten(0, 1)
+ x = self.patch_embed(x)
+ H, W = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x, H, W = layer(x, H, W)
+
+ if self.feat2d:
+ x = x.transpose(1, 2).reshape(-1, self.num_features, H, W)
+
+ if self.last_stage:
+
+ x = x.reshape(-1, H, W, self.num_features).mean(1)
+ x = self.last_conv(x)
+ x = self.hardswish(x)
+ x = self.dropout(x)
+ # x = self.avg_pool(x.transpose(1, 2).reshape(-1, self.num_features, H, W))
+ # x = self.last_conv(x)
+ # x = self.hardswish(x)
+ # x = self.dropout(x)
+ # x = x.flatten(2).transpose(1, 2)
+ return x
+
+ def flops(self):
+ flops = 0
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += self.num_features * self.patches_resolution[
+ 0] * self.patches_resolution[1] // (2**self.num_layers)
+ return flops
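+
+
+# Editor's usage sketch (not part of the original patch): driving FocalSVTR as a
+# recognition encoder with last_stage enabled; the input size matches the default
+# img_size, and the quoted output shape is an illustrative assumption.
+if __name__ == '__main__':
+    encoder = FocalSVTR(img_size=[32, 128], last_stage=True, out_channels=256)
+    seq = encoder(torch.randn(1, 3, 32, 128))
+    # the width is reduced to 128 / 4 = 32 tokens, so seq should be (1, 32, 256)
+    print(seq.shape)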
diff --git a/openrec/modeling/encoders/nrtr_encoder.py b/openrec/modeling/encoders/nrtr_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..488dbd56a5057d42d0a1fe98a74647769561852f
--- /dev/null
+++ b/openrec/modeling/encoders/nrtr_encoder.py
@@ -0,0 +1,28 @@
+from torch import nn
+
+
+class NRTREncoder(nn.Module):
+
+ def __init__(self, in_channels):
+ super(NRTREncoder, self).__init__()
+ self.out_channels = 512 # 64*H
+ self.block = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=32,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ), nn.ReLU(), nn.BatchNorm2d(32),
+ nn.Conv2d(
+ in_channels=32,
+ out_channels=64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ), nn.ReLU(), nn.BatchNorm2d(64))
+
+ def forward(self, images):
+ x = self.block(images)
+ x = x.permute(0, 3, 2, 1).flatten(2) # (B, W/4, 64 * H/4)
+ return x
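+
+
+# Editor's usage sketch (not part of the original patch): the encoder flattens
+# the feature map into a width-wise token sequence; the input size is an
+# assumption chosen so that 64 * (32 / 4) matches the hard-coded out_channels.
+if __name__ == '__main__':
+    import torch
+    encoder = NRTREncoder(in_channels=3)
+    tokens = encoder(torch.randn(1, 3, 32, 100))
+    print(tokens.shape)  # expected (1, 25, 512): 100 / 4 tokens of dim 64 * 32 / 4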
diff --git a/openrec/modeling/encoders/rec_hgnet.py b/openrec/modeling/encoders/rec_hgnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f90b817c6e2e6901303daeec0a134af72b8b4c8b
--- /dev/null
+++ b/openrec/modeling/encoders/rec_hgnet.py
@@ -0,0 +1,346 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ConvBNAct(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ groups=1,
+ use_act=True):
+ super().__init__()
+ self.use_act = use_act
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ bias=False,
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+ if self.use_act:
+ self.act = nn.ReLU()
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ if self.use_act:
+ x = self.act(x)
+ return x
+
+
+class ESEModule(nn.Module):
+
+ def __init__(self, channels):
+ super().__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.conv = nn.Conv2d(
+ in_channels=channels,
+ out_channels=channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ identity = x
+ x = self.avg_pool(x)
+ x = self.conv(x)
+ x = self.sigmoid(x)
+ return x * identity
+
+
+class HG_Block(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ mid_channels,
+ out_channels,
+ layer_num,
+ identity=False,
+ ):
+ super().__init__()
+ self.identity = identity
+
+ self.layers = nn.ModuleList()
+ self.layers.append(
+ ConvBNAct(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=3,
+ stride=1,
+ ))
+ for _ in range(layer_num - 1):
+ self.layers.append(
+ ConvBNAct(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ kernel_size=3,
+ stride=1,
+ ))
+
+ # feature aggregation
+ total_channels = in_channels + layer_num * mid_channels
+ self.aggregation_conv = ConvBNAct(
+ in_channels=total_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ )
+ self.att = ESEModule(out_channels)
+
+ def forward(self, x):
+ identity = x
+ output = []
+ output.append(x)
+ for layer in self.layers:
+ x = layer(x)
+ output.append(x)
+ x = torch.cat(output, dim=1)
+ x = self.aggregation_conv(x)
+ x = self.att(x)
+ if self.identity:
+ x += identity
+ return x
+
+
+class HG_Stage(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ mid_channels,
+ out_channels,
+ block_num,
+ layer_num,
+ downsample=True,
+ stride=[2, 1],
+ ):
+ super().__init__()
+ self.downsample = downsample
+ if downsample:
+ self.downsample = ConvBNAct(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=3,
+ stride=stride,
+ groups=in_channels,
+ use_act=False,
+ )
+
+ blocks_list = []
+ blocks_list.append(
+ HG_Block(in_channels,
+ mid_channels,
+ out_channels,
+ layer_num,
+ identity=False))
+ for _ in range(block_num - 1):
+ blocks_list.append(
+ HG_Block(out_channels,
+ mid_channels,
+ out_channels,
+ layer_num,
+ identity=True))
+ self.blocks = nn.Sequential(*blocks_list)
+
+ def forward(self, x):
+ if self.downsample:
+ x = self.downsample(x)
+ x = self.blocks(x)
+ return x
+
+
+class PPHGNet(nn.Module):
+ """
+ PPHGNet
+ Args:
+ stem_channels: list. Stem channel list of PPHGNet.
+ stage_config: dict. The configuration of each stage of PPHGNet: in_channels, mid_channels, out_channels, block_num, downsample, stride.
+ layer_num: int. Number of ConvBNAct layers in each HG_Block.
+ in_channels: int=3. Number of input image channels.
+ det: bool=False. If True, return multi-scale features for detection; otherwise return a single recognition feature map.
+ out_indices: list. Indices of the stages to output when det is True. Default: [0, 1, 2, 3].
+ Returns:
+ model: nn.Module. Specific PPHGNet model depends on args.
+ """
+
+ def __init__(
+ self,
+ stem_channels,
+ stage_config,
+ layer_num,
+ in_channels=3,
+ det=False,
+ out_indices=None,
+ ):
+ super().__init__()
+ self.det = det
+ self.out_indices = out_indices if out_indices is not None else [
+ 0, 1, 2, 3
+ ]
+
+ # stem
+ stem_channels.insert(0, in_channels)
+ self.stem = nn.Sequential(*[
+ ConvBNAct(
+ in_channels=stem_channels[i],
+ out_channels=stem_channels[i + 1],
+ kernel_size=3,
+ stride=2 if i == 0 else 1,
+ ) for i in range(len(stem_channels) - 1)
+ ])
+
+ if self.det:
+ self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ # stages
+ self.stages = nn.ModuleList()
+ self.out_channels = []
+ for block_id, k in enumerate(stage_config):
+ (
+ in_channels,
+ mid_channels,
+ out_channels,
+ block_num,
+ downsample,
+ stride,
+ ) = stage_config[k]
+ self.stages.append(
+ HG_Stage(
+ in_channels,
+ mid_channels,
+ out_channels,
+ block_num,
+ layer_num,
+ downsample,
+ stride,
+ ))
+ if block_id in self.out_indices:
+ self.out_channels.append(out_channels)
+
+ if not self.det:
+ self.out_channels = stage_config['stage4'][2]
+
+ self._init_weights()
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.zeros_(m.bias)
+
+ def forward(self, x):
+ x = self.stem(x)
+ if self.det:
+ x = self.pool(x)
+
+ out = []
+ for i, stage in enumerate(self.stages):
+ x = stage(x)
+ if self.det and i in self.out_indices:
+ out.append(x)
+ if self.det:
+ return out
+
+ if self.training:
+ x = F.adaptive_avg_pool2d(x, [1, 40])
+ else:
+ x = F.avg_pool2d(x, [3, 2])
+ return x
+
+
+def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs):
+ """
+ PPHGNet_tiny
+ Args:
+ pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+ If str, means the path of the pretrained model.
+ use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+ Returns:
+ model: nn.Module. Specific `PPHGNet_tiny` model depends on args.
+ """
+ stage_config = {
+ # in_channels, mid_channels, out_channels, blocks, downsample, stride
+ 'stage1': [96, 96, 224, 1, False, [2, 1]],
+ 'stage2': [224, 128, 448, 1, True, [1, 2]],
+ 'stage3': [448, 160, 512, 2, True, [2, 1]],
+ 'stage4': [512, 192, 768, 1, True, [2, 1]],
+ }
+
+ model = PPHGNet(stem_channels=[48, 48, 96],
+ stage_config=stage_config,
+ layer_num=5,
+ **kwargs)
+ return model
+
+
+def PPHGNet_small(pretrained=False, use_ssld=False, det=False, **kwargs):
+ """
+ PPHGNet_small
+ Args:
+ pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+ If str, means the path of the pretrained model.
+ use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+ Returns:
+ model: nn.Module. Specific `PPHGNet_small` model depends on args.
+ """
+ stage_config_det = {
+ # in_channels, mid_channels, out_channels, blocks, downsample, stride
+ 'stage1': [128, 128, 256, 1, False, 2],
+ 'stage2': [256, 160, 512, 1, True, 2],
+ 'stage3': [512, 192, 768, 2, True, 2],
+ 'stage4': [768, 224, 1024, 1, True, 2],
+ }
+
+ stage_config_rec = {
+ # in_channels, mid_channels, out_channels, blocks, downsample, stride
+ 'stage1': [128, 128, 256, 1, True, [2, 1]],
+ 'stage2': [256, 160, 512, 1, True, [1, 2]],
+ 'stage3': [512, 192, 768, 2, True, [2, 1]],
+ 'stage4': [768, 224, 1024, 1, True, [2, 1]],
+ }
+
+ model = PPHGNet(stem_channels=[64, 64, 128],
+ stage_config=stage_config_det if det else stage_config_rec,
+ layer_num=6,
+ det=det,
+ **kwargs)
+ return model
+
+
+def PPHGNet_base(pretrained=False, use_ssld=True, **kwargs):
+ """
+ PPHGNet_base
+ Args:
+ pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+ If str, means the path of the pretrained model.
+ use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+ Returns:
+ model: nn.Module. Specific `PPHGNet_base` model depends on args.
+ """
+ stage_config = {
+ # in_channels, mid_channels, out_channels, blocks, downsample, stride
+ 'stage1': [160, 192, 320, 1, False, [2, 1]],
+ 'stage2': [320, 224, 640, 2, True, [1, 2]],
+ 'stage3': [640, 256, 960, 3, True, [2, 1]],
+ 'stage4': [960, 288, 1280, 2, True, [2, 1]],
+ }
+
+ model = PPHGNet(stem_channels=[96, 96, 160],
+ stage_config=stage_config,
+ layer_num=7,
+ **kwargs)
+ return model
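+
+
+# Editor's usage sketch (not part of the original patch): the same factory serves
+# recognition (single feature map) and detection (one map per stage). The input
+# sizes below are illustrative assumptions; recognition assumes a 48-pixel height.
+if __name__ == '__main__':
+    rec_model = PPHGNet_small(det=False).eval()
+    det_model = PPHGNet_small(det=True).eval()
+    with torch.no_grad():
+        rec_feat = rec_model(torch.randn(1, 3, 48, 320))
+        det_feats = det_model(torch.randn(1, 3, 640, 640))
+    print(rec_feat.shape)                 # roughly (1, 1024, 1, 40)
+    print([f.shape for f in det_feats])   # one feature map per stage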
diff --git a/openrec/modeling/encoders/rec_lcnetv3.py b/openrec/modeling/encoders/rec_lcnetv3.py
new file mode 100644
index 0000000000000000000000000000000000000000..962346170acdc0342a5c41c078713f12505b1a82
--- /dev/null
+++ b/openrec/modeling/encoders/rec_lcnetv3.py
@@ -0,0 +1,488 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from openrec.modeling.common import Activation
+
+NET_CONFIG_det = {
+ 'blocks2':
+ # k, in_c, out_c, s, use_se
+ [[3, 16, 32, 1, False]],
+ 'blocks3': [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
+ 'blocks4': [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
+ 'blocks5': [
+ [3, 128, 256, 2, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ ],
+ 'blocks6': [
+ [5, 256, 512, 2, True],
+ [5, 512, 512, 1, True],
+ [5, 512, 512, 1, False],
+ [5, 512, 512, 1, False],
+ ],
+}
+
+NET_CONFIG_rec = {
+ 'blocks2':
+ # k, in_c, out_c, s, use_se
+ [[3, 16, 32, 1, False]],
+ 'blocks3': [[3, 32, 64, 1, False], [3, 64, 64, 1, False]],
+ 'blocks4': [[3, 64, 128, (2, 1), False], [3, 128, 128, 1, False]],
+ 'blocks5': [
+ [3, 128, 256, (1, 2), False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ ],
+ 'blocks6': [
+ [5, 256, 512, (2, 1), True],
+ [5, 512, 512, 1, True],
+ [5, 512, 512, (2, 1), False],
+ [5, 512, 512, 1, False],
+ ],
+}
+
+
+def make_divisible(v, divisor=16, min_value=None):
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class LearnableAffineBlock(nn.Module):
+
+ def __init__(self,
+ scale_value=1.0,
+ bias_value=0.0,
+ lr_mult=1.0,
+ lab_lr=0.1):
+ super().__init__()
+ self.scale = nn.Parameter(torch.Tensor([scale_value]))
+ self.bias = nn.Parameter(torch.Tensor([bias_value]))
+
+ def forward(self, x):
+ return self.scale * x + self.bias
+
+
+class ConvBNLayer(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ groups=1,
+ lr_mult=1.0):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ bias=False,
+ )
+
+ self.bn = nn.BatchNorm2d(out_channels)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+class Act(nn.Module):
+
+ def __init__(self, act='hard_swish', lr_mult=1.0, lab_lr=0.1):
+ super().__init__()
+ assert act in ['hard_swish', 'relu']
+ self.act = Activation(act)
+ self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
+
+ def forward(self, x):
+ return self.lab(self.act(x))
+
+
+class LearnableRepLayer(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ num_conv_branches=1,
+ lr_mult=1.0,
+ lab_lr=0.1,
+ ):
+ super().__init__()
+ self.is_repped = False
+ self.groups = groups
+ self.stride = stride
+ self.kernel_size = kernel_size
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_conv_branches = num_conv_branches
+ self.padding = (kernel_size - 1) // 2
+
+ self.identity = (nn.BatchNorm2d(in_channels) if
+ out_channels == in_channels and stride == 1 else None)
+
+ self.conv_kxk = nn.ModuleList([
+ ConvBNLayer(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ groups=groups,
+ lr_mult=lr_mult,
+ ) for _ in range(self.num_conv_branches)
+ ])
+
+ self.conv_1x1 = (ConvBNLayer(in_channels,
+ out_channels,
+ 1,
+ stride,
+ groups=groups,
+ lr_mult=lr_mult)
+ if kernel_size > 1 else None)
+
+ self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
+ self.act = Act(lr_mult=lr_mult, lab_lr=lab_lr)
+
+ def forward(self, x):
+ # for export
+ if self.is_repped:
+ out = self.lab(self.reparam_conv(x))
+ if self.stride != 2:
+ out = self.act(out)
+ return out
+
+ out = 0
+ if self.identity is not None:
+ out += self.identity(x)
+
+ if self.conv_1x1 is not None:
+ out += self.conv_1x1(x)
+
+ for conv in self.conv_kxk:
+ out += conv(x)
+
+ out = self.lab(out)
+ if self.stride != 2:
+ out = self.act(out)
+ return out
+
+ def rep(self):
+ if self.is_repped:
+ return
+ kernel, bias = self._get_kernel_bias()
+ self.reparam_conv = nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ groups=self.groups,
+ )
+ self.reparam_conv.weight.data = kernel
+ self.reparam_conv.bias.data = bias
+ self.is_repped = True
+
+ def _pad_kernel_1x1_to_kxk(self, kernel1x1, pad):
+ if not isinstance(kernel1x1, torch.Tensor):
+ return 0
+ else:
+ return nn.functional.pad(kernel1x1, [pad, pad, pad, pad])
+
+ def _get_kernel_bias(self):
+ kernel_conv_1x1, bias_conv_1x1 = self._fuse_bn_tensor(self.conv_1x1)
+ kernel_conv_1x1 = self._pad_kernel_1x1_to_kxk(kernel_conv_1x1,
+ self.kernel_size // 2)
+
+ kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)
+
+ kernel_conv_kxk = 0
+ bias_conv_kxk = 0
+ for conv in self.conv_kxk:
+ kernel, bias = self._fuse_bn_tensor(conv)
+ kernel_conv_kxk += kernel
+ bias_conv_kxk += bias
+
+ kernel_reparam = kernel_conv_kxk + kernel_conv_1x1 + kernel_identity
+ bias_reparam = bias_conv_kxk + bias_conv_1x1 + bias_identity
+ return kernel_reparam, bias_reparam
+
+ def _fuse_bn_tensor(self, branch):
+ if not branch:
+ return 0, 0
+ elif isinstance(branch, ConvBNLayer):
+ kernel = branch.conv.weight
+ running_mean = branch.bn.running_mean
+ running_var = branch.bn.running_var
+ gamma = branch.bn.weight
+ beta = branch.bn.bias
+ eps = branch.bn.eps
+ else:
+ assert isinstance(branch, nn.BatchNorm2d)
+ if not hasattr(self, 'id_tensor'):
+ input_dim = self.in_channels // self.groups
+ kernel_value = torch.zeros(
+ (self.in_channels, input_dim, self.kernel_size,
+ self.kernel_size),
+ dtype=branch.weight.dtype,
+ )
+ for i in range(self.in_channels):
+ kernel_value[i, i % input_dim, self.kernel_size // 2,
+ self.kernel_size // 2] = 1
+ self.id_tensor = kernel_value
+ kernel = self.id_tensor
+ running_mean = branch.running_mean
+ running_var = branch.running_var
+ gamma = branch.weight
+ beta = branch.bias
+ eps = branch.eps
+ std = (running_var + eps).sqrt()
+ t = (gamma / std).reshape((-1, 1, 1, 1))
+ return kernel * t, beta - running_mean * gamma / std
+
+
+class SELayer(nn.Module):
+
+ def __init__(self, channel, reduction=4, lr_mult=1.0):
+ super().__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = nn.Conv2d(
+ in_channels=channel,
+ out_channels=channel // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ self.relu = nn.ReLU()
+ self.conv2 = nn.Conv2d(
+ in_channels=channel // reduction,
+ out_channels=channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ self.hardsigmoid = Activation('hard_sigmoid')
+
+ def forward(self, x):
+ identity = x
+ x = self.avg_pool(x)
+ x = self.conv1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.hardsigmoid(x)
+ x = x * identity
+ return x
+
+
+class LCNetV3Block(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ dw_size,
+ use_se=False,
+ conv_kxk_num=4,
+ lr_mult=1.0,
+ lab_lr=0.1,
+ ):
+ super().__init__()
+ self.use_se = use_se
+ self.dw_conv = LearnableRepLayer(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=dw_size,
+ stride=stride,
+ groups=in_channels,
+ num_conv_branches=conv_kxk_num,
+ lr_mult=lr_mult,
+ lab_lr=lab_lr,
+ )
+ if use_se:
+ self.se = SELayer(in_channels, lr_mult=lr_mult)
+ self.pw_conv = LearnableRepLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ num_conv_branches=conv_kxk_num,
+ lr_mult=lr_mult,
+ lab_lr=lab_lr,
+ )
+
+ def forward(self, x):
+ x = self.dw_conv(x)
+ if self.use_se:
+ x = self.se(x)
+ x = self.pw_conv(x)
+ return x
+
+
+class PPLCNetV3(nn.Module):
+
+ def __init__(self,
+ scale=1.0,
+ conv_kxk_num=4,
+ lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+ lab_lr=0.1,
+ det=False,
+ **kwargs):
+ super().__init__()
+ self.scale = scale
+ self.lr_mult_list = lr_mult_list
+ self.det = det
+
+ self.net_config = NET_CONFIG_det if self.det else NET_CONFIG_rec
+
+ assert isinstance(
+ self.lr_mult_list,
+ (list, tuple
+ )), 'lr_mult_list should be in (list, tuple) but got {}'.format(
+ type(self.lr_mult_list))
+ assert len(self.lr_mult_list
+ ) == 6, 'lr_mult_list length should be 6 but got {}'.format(
+ len(self.lr_mult_list))
+
+ self.conv1 = ConvBNLayer(
+ in_channels=3,
+ out_channels=make_divisible(16 * scale),
+ kernel_size=3,
+ stride=2,
+ lr_mult=self.lr_mult_list[0],
+ )
+
+ self.blocks2 = nn.Sequential(*[
+ LCNetV3Block(
+ in_channels=make_divisible(in_c * scale),
+ out_channels=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ conv_kxk_num=conv_kxk_num,
+ lr_mult=self.lr_mult_list[1],
+ lab_lr=lab_lr,
+ ) for i, (k, in_c, out_c, s,
+ se) in enumerate(self.net_config['blocks2'])
+ ])
+
+ self.blocks3 = nn.Sequential(*[
+ LCNetV3Block(
+ in_channels=make_divisible(in_c * scale),
+ out_channels=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ conv_kxk_num=conv_kxk_num,
+ lr_mult=self.lr_mult_list[2],
+ lab_lr=lab_lr,
+ ) for i, (k, in_c, out_c, s,
+ se) in enumerate(self.net_config['blocks3'])
+ ])
+
+ self.blocks4 = nn.Sequential(*[
+ LCNetV3Block(
+ in_channels=make_divisible(in_c * scale),
+ out_channels=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ conv_kxk_num=conv_kxk_num,
+ lr_mult=self.lr_mult_list[3],
+ lab_lr=lab_lr,
+ ) for i, (k, in_c, out_c, s,
+ se) in enumerate(self.net_config['blocks4'])
+ ])
+
+ self.blocks5 = nn.Sequential(*[
+ LCNetV3Block(
+ in_channels=make_divisible(in_c * scale),
+ out_channels=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ conv_kxk_num=conv_kxk_num,
+ lr_mult=self.lr_mult_list[4],
+ lab_lr=lab_lr,
+ ) for i, (k, in_c, out_c, s,
+ se) in enumerate(self.net_config['blocks5'])
+ ])
+
+ self.blocks6 = nn.Sequential(*[
+ LCNetV3Block(
+ in_channels=make_divisible(in_c * scale),
+ out_channels=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ conv_kxk_num=conv_kxk_num,
+ lr_mult=self.lr_mult_list[5],
+ lab_lr=lab_lr,
+ ) for i, (k, in_c, out_c, s,
+ se) in enumerate(self.net_config['blocks6'])
+ ])
+ self.out_channels = make_divisible(512 * scale)
+
+ if self.det:
+ mv_c = [16, 24, 56, 480]
+ self.out_channels = [
+ make_divisible(self.net_config['blocks3'][-1][2] * scale),
+ make_divisible(self.net_config['blocks4'][-1][2] * scale),
+ make_divisible(self.net_config['blocks5'][-1][2] * scale),
+ make_divisible(self.net_config['blocks6'][-1][2] * scale),
+ ]
+
+ self.layer_list = nn.ModuleList([
+ nn.Conv2d(self.out_channels[0], int(mv_c[0] * scale), 1, 1, 0),
+ nn.Conv2d(self.out_channels[1], int(mv_c[1] * scale), 1, 1, 0),
+ nn.Conv2d(self.out_channels[2], int(mv_c[2] * scale), 1, 1, 0),
+ nn.Conv2d(self.out_channels[3], int(mv_c[3] * scale), 1, 1, 0),
+ ])
+ self.out_channels = [
+ int(mv_c[0] * scale),
+ int(mv_c[1] * scale),
+ int(mv_c[2] * scale),
+ int(mv_c[3] * scale),
+ ]
+
+ def forward(self, x):
+ out_list = []
+ x = self.conv1(x)
+
+ x = self.blocks2(x)
+ x = self.blocks3(x)
+ out_list.append(x)
+ x = self.blocks4(x)
+ out_list.append(x)
+ x = self.blocks5(x)
+ out_list.append(x)
+ x = self.blocks6(x)
+ out_list.append(x)
+
+ if self.det:
+ out_list[0] = self.layer_list[0](out_list[0])
+ out_list[1] = self.layer_list[1](out_list[1])
+ out_list[2] = self.layer_list[2](out_list[2])
+ out_list[3] = self.layer_list[3](out_list[3])
+ return out_list
+
+ if self.training:
+ x = F.adaptive_avg_pool2d(x, [1, 40])
+ else:
+ x = F.avg_pool2d(x, [3, 2])
+ return x
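+
+
+# Editor's usage sketch (not part of the original patch): LearnableRepLayer keeps
+# several parallel branches during training; before export they can be fused into
+# a single conv via rep(). Scale and input size below are illustrative assumptions.
+if __name__ == '__main__':
+    model = PPLCNetV3(scale=0.95, det=False).eval()
+    for m in model.modules():
+        if isinstance(m, LearnableRepLayer):
+            m.rep()  # fuse the kxk, 1x1 and identity branches into reparam_conv
+    with torch.no_grad():
+        feat = model(torch.randn(1, 3, 48, 320))
+    print(feat.shape)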
diff --git a/openrec/modeling/encoders/rec_mobilenet_v3.py b/openrec/modeling/encoders/rec_mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..77736c03c5e91beff079df0cb85392f1c128f4cd
--- /dev/null
+++ b/openrec/modeling/encoders/rec_mobilenet_v3.py
@@ -0,0 +1,132 @@
+import torch.nn as nn
+
+from .det_mobilenet_v3 import ConvBNLayer, ResidualUnit, make_divisible
+
+
+class MobileNetV3(nn.Module):
+
+ def __init__(self,
+ in_channels=3,
+ model_name='small',
+ scale=0.5,
+ large_stride=None,
+ small_stride=None,
+ **kwargs):
+ super(MobileNetV3, self).__init__()
+ if small_stride is None:
+ small_stride = [2, 2, 2, 2]
+ if large_stride is None:
+ large_stride = [1, 2, 2, 2]
+
+ assert isinstance(
+ large_stride,
+ list), 'large_stride type must ' 'be list but got {}'.format(
+ type(large_stride))
+ assert isinstance(
+ small_stride,
+ list), 'small_stride type must ' 'be list but got {}'.format(
+ type(small_stride))
+ assert len(
+ large_stride
+ ) == 4, 'large_stride length must be ' '4 but got {}'.format(
+ len(large_stride))
+ assert len(
+ small_stride
+ ) == 4, 'small_stride length must be ' '4 but got {}'.format(
+ len(small_stride))
+
+ if model_name == 'large':
+ cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, False, 'relu', large_stride[0]],
+ [3, 64, 24, False, 'relu', (large_stride[1], 1)],
+ [3, 72, 24, False, 'relu', 1],
+ [5, 72, 40, True, 'relu', (large_stride[2], 1)],
+ [5, 120, 40, True, 'relu', 1],
+ [5, 120, 40, True, 'relu', 1],
+ [3, 240, 80, False, 'hard_swish', 1],
+ [3, 200, 80, False, 'hard_swish', 1],
+ [3, 184, 80, False, 'hard_swish', 1],
+ [3, 184, 80, False, 'hard_swish', 1],
+ [3, 480, 112, True, 'hard_swish', 1],
+ [3, 672, 112, True, 'hard_swish', 1],
+ [5, 672, 160, True, 'hard_swish', (large_stride[3], 1)],
+ [5, 960, 160, True, 'hard_swish', 1],
+ [5, 960, 160, True, 'hard_swish', 1],
+ ]
+ cls_ch_squeeze = 960
+ elif model_name == 'small':
+ cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, True, 'relu', (small_stride[0], 1)],
+ [3, 72, 24, False, 'relu', (small_stride[1], 1)],
+ [3, 88, 24, False, 'relu', 1],
+ [5, 96, 40, True, 'hard_swish', (small_stride[2], 1)],
+ [5, 240, 40, True, 'hard_swish', 1],
+ [5, 240, 40, True, 'hard_swish', 1],
+ [5, 120, 48, True, 'hard_swish', 1],
+ [5, 144, 48, True, 'hard_swish', 1],
+ [5, 288, 96, True, 'hard_swish', (small_stride[3], 1)],
+ [5, 576, 96, True, 'hard_swish', 1],
+ [5, 576, 96, True, 'hard_swish', 1],
+ ]
+ cls_ch_squeeze = 576
+ else:
+ raise NotImplementedError('mode[' + model_name +
+ '_model] is not implemented!')
+
+ supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
+ assert scale in supported_scale, 'supported scales are {} but input scale is {}'.format(
+ supported_scale, scale)
+
+ inplanes = 16
+ # conv1
+ self.conv1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=make_divisible(inplanes * scale),
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=1,
+ if_act=True,
+ act='hard_swish',
+ )
+ i = 0
+ block_list = []
+ inplanes = make_divisible(inplanes * scale)
+ for k, exp, c, se, nl, s in cfg:
+ block_list.append(
+ ResidualUnit(
+ in_channels=inplanes,
+ mid_channels=make_divisible(scale * exp),
+ out_channels=make_divisible(scale * c),
+ kernel_size=k,
+ stride=s,
+ use_se=se,
+ act=nl,
+ name='conv' + str(i + 2),
+ ))
+ inplanes = make_divisible(scale * c)
+ i += 1
+ self.blocks = nn.Sequential(*block_list)
+
+ self.conv2 = ConvBNLayer(
+ in_channels=inplanes,
+ out_channels=make_divisible(scale * cls_ch_squeeze),
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ if_act=True,
+ act='hard_swish',
+ )
+
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+ self.out_channels = make_divisible(scale * cls_ch_squeeze)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.blocks(x)
+ x = self.conv2(x)
+ x = self.pool(x)
+ return x
diff --git a/openrec/modeling/encoders/rec_mv1_enhance.py b/openrec/modeling/encoders/rec_mv1_enhance.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a21a3186f567b0de9e35e7981784f32a9796557
--- /dev/null
+++ b/openrec/modeling/encoders/rec_mv1_enhance.py
@@ -0,0 +1,254 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from openrec.modeling.common import Activation
+
+
+class ConvBNLayer(nn.Module):
+
+ def __init__(
+ self,
+ num_channels,
+ filter_size,
+ num_filters,
+ stride,
+ padding,
+ num_groups=1,
+ act='hard_swish',
+ ):
+ super(ConvBNLayer, self).__init__()
+ self.act = act
+ self._conv = nn.Conv2d(
+ in_channels=num_channels,
+ out_channels=num_filters,
+ kernel_size=filter_size,
+ stride=stride,
+ padding=padding,
+ groups=num_groups,
+ bias=False,
+ )
+
+ self._batch_norm = nn.BatchNorm2d(num_filters, )
+ if self.act is not None:
+ self._act = Activation(act_type=act, inplace=True)
+
+ def forward(self, inputs):
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ if self.act is not None:
+ y = self._act(y)
+ return y
+
+
+class DepthwiseSeparable(nn.Module):
+
+ def __init__(
+ self,
+ num_channels,
+ num_filters1,
+ num_filters2,
+ num_groups,
+ stride,
+ scale,
+ dw_size=3,
+ padding=1,
+ use_se=False,
+ ):
+ super(DepthwiseSeparable, self).__init__()
+ self._depthwise_conv = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=int(num_filters1 * scale),
+ filter_size=dw_size,
+ stride=stride,
+ padding=padding,
+ num_groups=int(num_groups * scale),
+ )
+ self._se = None
+ if use_se:
+ self._se = SEModule(int(num_filters1 * scale))
+ self._pointwise_conv = ConvBNLayer(
+ num_channels=int(num_filters1 * scale),
+ filter_size=1,
+ num_filters=int(num_filters2 * scale),
+ stride=1,
+ padding=0,
+ )
+
+ def forward(self, inputs):
+ y = self._depthwise_conv(inputs)
+ if self._se is not None:
+ y = self._se(y)
+ y = self._pointwise_conv(y)
+ return y
+
+
+class MobileNetV1Enhance(nn.Module):
+
+ def __init__(self,
+ in_channels=3,
+ scale=0.5,
+ last_conv_stride=1,
+ last_pool_type='max',
+ **kwargs):
+ super().__init__()
+ self.scale = scale
+ self.block_list = []
+
+ self.conv1 = ConvBNLayer(
+ num_channels=in_channels,
+ filter_size=3,
+ num_filters=int(32 * scale),
+ stride=2,
+ padding=1,
+ )
+
+ conv2_1 = DepthwiseSeparable(
+ num_channels=int(32 * scale),
+ num_filters1=32,
+ num_filters2=64,
+ num_groups=32,
+ stride=1,
+ scale=scale,
+ )
+ self.block_list.append(conv2_1)
+
+ conv2_2 = DepthwiseSeparable(
+ num_channels=int(64 * scale),
+ num_filters1=64,
+ num_filters2=128,
+ num_groups=64,
+ stride=1,
+ scale=scale,
+ )
+ self.block_list.append(conv2_2)
+
+ conv3_1 = DepthwiseSeparable(
+ num_channels=int(128 * scale),
+ num_filters1=128,
+ num_filters2=128,
+ num_groups=128,
+ stride=1,
+ scale=scale,
+ )
+ self.block_list.append(conv3_1)
+
+ conv3_2 = DepthwiseSeparable(
+ num_channels=int(128 * scale),
+ num_filters1=128,
+ num_filters2=256,
+ num_groups=128,
+ stride=(2, 1),
+ scale=scale,
+ )
+ self.block_list.append(conv3_2)
+
+ conv4_1 = DepthwiseSeparable(
+ num_channels=int(256 * scale),
+ num_filters1=256,
+ num_filters2=256,
+ num_groups=256,
+ stride=1,
+ scale=scale,
+ )
+ self.block_list.append(conv4_1)
+
+ conv4_2 = DepthwiseSeparable(
+ num_channels=int(256 * scale),
+ num_filters1=256,
+ num_filters2=512,
+ num_groups=256,
+ stride=(2, 1),
+ scale=scale,
+ )
+ self.block_list.append(conv4_2)
+
+ for _ in range(5):
+ conv5 = DepthwiseSeparable(
+ num_channels=int(512 * scale),
+ num_filters1=512,
+ num_filters2=512,
+ num_groups=512,
+ stride=1,
+ dw_size=5,
+ padding=2,
+ scale=scale,
+ use_se=False,
+ )
+ self.block_list.append(conv5)
+
+ conv5_6 = DepthwiseSeparable(
+ num_channels=int(512 * scale),
+ num_filters1=512,
+ num_filters2=1024,
+ num_groups=512,
+ stride=(2, 1),
+ dw_size=5,
+ padding=2,
+ scale=scale,
+ use_se=True,
+ )
+ self.block_list.append(conv5_6)
+
+ conv6 = DepthwiseSeparable(
+ num_channels=int(1024 * scale),
+ num_filters1=1024,
+ num_filters2=1024,
+ num_groups=1024,
+ stride=last_conv_stride,
+ dw_size=5,
+ padding=2,
+ use_se=True,
+ scale=scale,
+ )
+ self.block_list.append(conv6)
+
+ self.block_list = nn.Sequential(*self.block_list)
+ if last_pool_type == 'avg':
+ self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
+ else:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+ self.out_channels = int(1024 * scale)
+
+ def forward(self, inputs):
+ y = self.conv1(inputs)
+ y = self.block_list(y)
+ y = self.pool(y)
+ return y
+
+
+def hardsigmoid(x):
+ return F.relu6(x + 3.0, inplace=True) / 6.0
+
+
+class SEModule(nn.Module):
+
+ def __init__(self, channel, reduction=4):
+ super(SEModule, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = nn.Conv2d(
+ in_channels=channel,
+ out_channels=channel // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ )
+ self.conv2 = nn.Conv2d(
+ in_channels=channel // reduction,
+ out_channels=channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ )
+
+ def forward(self, inputs):
+ outputs = self.avg_pool(inputs)
+ outputs = self.conv1(outputs)
+ outputs = F.relu(outputs)
+ outputs = self.conv2(outputs)
+ outputs = hardsigmoid(outputs)
+ x = torch.mul(inputs, outputs)
+
+ return x
diff --git a/openrec/modeling/encoders/rec_nrtr_mtb.py b/openrec/modeling/encoders/rec_nrtr_mtb.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f9e9fa2ccf3981c9c706b8e8a067073d5995819
--- /dev/null
+++ b/openrec/modeling/encoders/rec_nrtr_mtb.py
@@ -0,0 +1,37 @@
+import torch
+from torch import nn
+
+
+class MTB(nn.Module):
+
+ def __init__(self, cnn_num, in_channels):
+ super(MTB, self).__init__()
+ self.block = nn.Sequential()
+ self.out_channels = in_channels
+ self.cnn_num = cnn_num
+ if self.cnn_num == 2:
+ for i in range(self.cnn_num):
+ self.block.add_module(
+ 'conv_{}'.format(i),
+ nn.Conv2d(
+ in_channels=in_channels if i == 0 else 32 *
+ (2**(i - 1)),
+ out_channels=32 * (2**i),
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+ self.block.add_module('relu_{}'.format(i), nn.ReLU())
+ self.block.add_module('bn_{}'.format(i),
+ nn.BatchNorm2d(32 * (2**i)))
+
+ def forward(self, images):
+ x = self.block(images)
+ if self.cnn_num == 2:
+ # (b, w, h, c)
+ x = x.permute(0, 3, 2, 1)
+ x_shape = x.shape
+ x = torch.reshape(
+ x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3]))
+ return x
diff --git a/openrec/modeling/encoders/rec_resnet_31.py b/openrec/modeling/encoders/rec_resnet_31.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b04cb414a6d1781946e61330c56708d03afa826
--- /dev/null
+++ b/openrec/modeling/encoders/rec_resnet_31.py
@@ -0,0 +1,213 @@
+import torch.nn as nn
+
+__all__ = ['ResNet31']
+
+
+def conv3x3(in_channel, out_channel, stride=1):
+ return nn.Conv2d(in_channel,
+ out_channel,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, in_channels, channels, stride=1, downsample=False):
+ super().__init__()
+ self.conv1 = conv3x3(in_channels, channels, stride)
+ self.bn1 = nn.BatchNorm2d(channels)
+ self.relu = nn.ReLU()
+ self.conv2 = conv3x3(channels, channels)
+ self.bn2 = nn.BatchNorm2d(channels)
+ self.downsample = downsample
+ if downsample:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_channels,
+ channels * self.expansion,
+ 1,
+ stride,
+ bias=False),
+ nn.BatchNorm2d(channels * self.expansion),
+ )
+ else:
+ self.downsample = nn.Sequential()
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet31(nn.Module):
+ """
+ Args:
+ in_channels (int): Number of channels of input image tensor.
+ layers (list[int]): List of BasicBlock number for each stage.
+ channels (list[int]): List of out_channels of Conv2d layer.
+ out_indices (None | Sequence[int]): Indices of output stages.
+ last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
+ """
+
+ def __init__(
+ self,
+ in_channels=3,
+ layers=[1, 2, 5, 3],
+ channels=[64, 128, 256, 256, 512, 512, 512],
+ out_indices=None,
+ last_stage_pool=False,
+ ):
+ super(ResNet31, self).__init__()
+ assert isinstance(in_channels, int)
+ assert isinstance(last_stage_pool, bool)
+
+ self.out_indices = out_indices
+ self.last_stage_pool = last_stage_pool
+
+ # conv 1 (Conv Conv)
+ self.conv1_1 = nn.Conv2d(in_channels,
+ channels[0],
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.bn1_1 = nn.BatchNorm2d(channels[0])
+ self.relu1_1 = nn.ReLU(inplace=True)
+
+ self.conv1_2 = nn.Conv2d(channels[0],
+ channels[1],
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.bn1_2 = nn.BatchNorm2d(channels[1])
+ self.relu1_2 = nn.ReLU(inplace=True)
+
+ # conv 2 (Max-pooling, Residual block, Conv)
+ self.pool2 = nn.MaxPool2d(kernel_size=2,
+ stride=2,
+ padding=0,
+ ceil_mode=True)
+ self.block2 = self._make_layer(channels[1], channels[2], layers[0])
+ self.conv2 = nn.Conv2d(channels[2],
+ channels[2],
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.bn2 = nn.BatchNorm2d(channels[2])
+ self.relu2 = nn.ReLU(inplace=True)
+
+ # conv 3 (Max-pooling, Residual block, Conv)
+ self.pool3 = nn.MaxPool2d(kernel_size=2,
+ stride=2,
+ padding=0,
+ ceil_mode=True)
+ self.block3 = self._make_layer(channels[2], channels[3], layers[1])
+ self.conv3 = nn.Conv2d(channels[3],
+ channels[3],
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.bn3 = nn.BatchNorm2d(channels[3])
+ self.relu3 = nn.ReLU(inplace=True)
+
+ # conv 4 (Max-pooling, Residual block, Conv)
+ self.pool4 = nn.MaxPool2d(kernel_size=(2, 1),
+ stride=(2, 1),
+ padding=0,
+ ceil_mode=True)
+ self.block4 = self._make_layer(channels[3], channels[4], layers[2])
+ self.conv4 = nn.Conv2d(channels[4],
+ channels[4],
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.bn4 = nn.BatchNorm2d(channels[4])
+ self.relu4 = nn.ReLU(inplace=True)
+
+ # conv 5 ((Max-pooling), Residual block, Conv)
+ self.pool5 = None
+ if self.last_stage_pool:
+ self.pool5 = nn.MaxPool2d(kernel_size=2,
+ stride=2,
+ padding=0,
+ ceil_mode=True)
+ self.block5 = self._make_layer(channels[4], channels[5], layers[3])
+ self.conv5 = nn.Conv2d(channels[5],
+ channels[5],
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.bn5 = nn.BatchNorm2d(channels[5])
+ self.relu5 = nn.ReLU(inplace=True)
+
+ self.out_channels = channels[-1]
+
+ def _make_layer(self, input_channels, output_channels, blocks):
+ layers = []
+ for _ in range(blocks):
+ downsample = None
+ if input_channels != output_channels:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ input_channels,
+ output_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(output_channels),
+ )
+
+ layers.append(
+ BasicBlock(input_channels,
+ output_channels,
+ downsample=downsample))
+ input_channels = output_channels
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1_1(x)
+ x = self.bn1_1(x)
+ x = self.relu1_1(x)
+
+ x = self.conv1_2(x)
+ x = self.bn1_2(x)
+ x = self.relu1_2(x)
+
+ outs = []
+ for i in range(4):
+ layer_index = i + 2
+ pool_layer = getattr(self, 'pool{}'.format(layer_index))
+ block_layer = getattr(self, 'block{}'.format(layer_index))
+ conv_layer = getattr(self, 'conv{}'.format(layer_index))
+ bn_layer = getattr(self, 'bn{}'.format(layer_index))
+ relu_layer = getattr(self, 'relu{}'.format(layer_index))
+
+ if pool_layer is not None:
+ x = pool_layer(x)
+ x = block_layer(x)
+ x = conv_layer(x)
+ x = bn_layer(x)
+ x = relu_layer(x)
+
+ outs.append(x)
+
+ if self.out_indices is not None:
+ return tuple([outs[i] for i in self.out_indices])
+
+ return x
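+
+
+# Editor's usage sketch (not part of the original patch): with the default
+# configuration the height is reduced by 8x and the width by 4x. The input size
+# and the resulting shape are illustrative assumptions.
+if __name__ == '__main__':
+    import torch
+    encoder = ResNet31(in_channels=3)
+    feat = encoder(torch.randn(1, 3, 32, 100))
+    print(feat.shape)  # expected (1, 512, 4, 25)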
diff --git a/openrec/modeling/encoders/rec_resnet_45.py b/openrec/modeling/encoders/rec_resnet_45.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e4e2c1d1f736a91229511e59d94979d08cb8f79
--- /dev/null
+++ b/openrec/modeling/encoders/rec_resnet_45.py
@@ -0,0 +1,183 @@
+import math
+
+import numpy as np
+import torch.nn as nn
+
+from openrec.modeling.common import Block
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding."""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv1x1(inplanes, planes)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet45(nn.Module):
+
+ def __init__(
+ self,
+ in_channels=3,
+ block=BasicBlock,
+ layers=[3, 4, 6, 6, 3],
+ strides=[2, 1, 2, 1, 1],
+ last_stage=False,
+ out_channels=256,
+ trans_layer=0,
+ out_dim=384,
+ feat2d=True,
+ return_list=False,
+ ):
+ super(ResNet45, self).__init__()
+ self.inplanes = 32
+ self.conv1 = nn.Conv2d(in_channels,
+ 32,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(32)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, 32, layers[0], stride=strides[0])
+ self.layer2 = self._make_layer(block, 64, layers[1], stride=strides[1])
+ self.layer3 = self._make_layer(block,
+ 128,
+ layers[2],
+ stride=strides[2])
+ self.layer4 = self._make_layer(block,
+ 256,
+ layers[3],
+ stride=strides[3])
+ self.layer5 = self._make_layer(block,
+ 512,
+ layers[4],
+ stride=strides[4])
+ self.out_channels = 512
+ self.feat2d = feat2d
+ self.return_list = return_list
+ if trans_layer > 0:
+ dpr = np.linspace(0, 0.1, trans_layer)
+ blocks = [nn.Linear(512, out_dim)] + [
+ Block(dim=out_dim,
+ num_heads=out_dim // 32,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ drop_path=dpr[i]) for i in range(trans_layer)
+ ]
+ self.trans_blocks = nn.Sequential(*blocks)
+ dim = out_dim
+ self.out_channels = out_dim
+ else:
+ self.trans_blocks = None
+ dim = 512
+ self.last_stage = last_stage
+ if last_stage:
+ self.out_channels = out_channels
+ self.last_conv = nn.Linear(dim, self.out_channels, bias=False)
+ self.hardswish = nn.Hardswish()
+ self.dropout = nn.Dropout(p=0.1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2.0 / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, mean=0, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ ),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+ x2 = self.layer2(x)
+ x3 = self.layer3(x2)
+ x4 = self.layer4(x3)
+ x5 = self.layer5(x4)
+
+ if self.return_list:
+ return [x2, x3, x4, x5]
+ x = x5
+ if self.trans_blocks is not None:
+ B, C, H, W = x.shape
+ x = self.trans_blocks(x.flatten(2, 3).transpose(1, 2))
+ x = x.transpose(1, 2).reshape(B, -1, H, W)
+
+ if self.last_stage:
+ x = x.mean(2).transpose(1, 2)
+ x = self.last_conv(x)
+ x = self.hardswish(x)
+ x = self.dropout(x)
+ elif not self.feat2d:
+ x = x.flatten(2).transpose(1, 2)
+ return x
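+
+
+if __name__ == '__main__':
+    # Minimal smoke-test sketch (illustrative only; the 32x100 input size is
+    # an assumption matching common text-recognition settings).
+    import torch
+
+    model = ResNet45(in_channels=3)
+    feats = model(torch.randn(1, 3, 32, 100))
+    # With the default strides [2, 1, 2, 1, 1] and feat2d=True, this should
+    # be a [1, 512, 8, 25] feature map.
+    print(feats.shape)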
diff --git a/openrec/modeling/encoders/rec_resnet_fpn.py b/openrec/modeling/encoders/rec_resnet_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a749503fd250b7c65ed84a4926477fdac204593e
--- /dev/null
+++ b/openrec/modeling/encoders/rec_resnet_fpn.py
@@ -0,0 +1,216 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+class ConvBNLayer(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel,
+ stride=1,
+ act='ReLU'):
+ super(ConvBNLayer, self).__init__()
+ self.act_flag = act
+ self.conv = nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=2 if stride == (1, 1) else kernel,
+ stride=stride,
+ padding=(kernel - 1) // 2,
+ dilation=2 if stride == (1, 1) else 1)
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.act = nn.ReLU(True)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ if self.act_flag != 'None':
+ x = self.act(x)
+ return x
+
+
+class Shortcut(nn.Module):
+
+ def __init__(self, in_channels, out_channels, stride, is_first=False):
+ super(Shortcut, self).__init__()
+ self.use_conv = True
+ if in_channels != out_channels or stride != 1 or is_first is True:
+ if stride == (1, 1):
+ self.conv = ConvBNLayer(in_channels, out_channels, 1, 1)
+ else:
+ self.conv = ConvBNLayer(in_channels, out_channels, 1, stride)
+ else:
+ self.use_conv = False
+
+ def forward(self, x):
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class BottleneckBlock(nn.Module):
+
+ def __init__(self, in_channels, out_channels, stride):
+ super(BottleneckBlock, self).__init__()
+ self.conv0 = ConvBNLayer(in_channels, out_channels, kernel=1)
+ self.conv1 = ConvBNLayer(out_channels,
+ out_channels,
+ kernel=3,
+ stride=stride)
+ self.conv2 = ConvBNLayer(out_channels,
+ out_channels * 4,
+ kernel=1,
+ act='None')
+ self.short = Shortcut(in_channels, out_channels * 4, stride=stride)
+ self.out_channels = out_channels * 4
+ self.relu = nn.ReLU(True)
+
+ def forward(self, x):
+ y = self.conv0(x)
+ y = self.conv1(y)
+ y = self.conv2(y)
+ y = y + self.short(x)
+ y = self.relu(y)
+ return y
+
+
+class BasicBlock(nn.Module):
+
+ def __init__(self, in_channels, out_channels, stride, is_first):
+ super(BasicBlock, self).__init__()
+ self.conv0 = ConvBNLayer(in_channels,
+ out_channels,
+ kernel=3,
+ stride=stride)
+ self.conv1 = ConvBNLayer(out_channels,
+ out_channels,
+ kernel=3,
+ act='None')
+ self.short = Shortcut(in_channels, out_channels, stride, is_first)
+        self.out_channels = out_channels
+ self.relu = nn.ReLU(True)
+
+ def forward(self, x):
+ y = self.conv0(x)
+ y = self.conv1(y)
+ y = y + self.short(x)
+ y = self.relu(y)
+ return y
+
+
+class ResNet_FPN(nn.Module):
+
+ def __init__(self, in_channels=1, layers=50, **kwargs):
+ super(ResNet_FPN, self).__init__()
+ supported_layers = {
+ 18: {
+ 'depth': [2, 2, 2, 2],
+ 'block_class': BasicBlock
+ },
+ 34: {
+ 'depth': [3, 4, 6, 3],
+ 'block_class': BasicBlock
+ },
+ 50: {
+ 'depth': [3, 4, 6, 3],
+ 'block_class': BottleneckBlock
+ },
+ 101: {
+ 'depth': [3, 4, 23, 3],
+ 'block_class': BottleneckBlock
+ },
+ 152: {
+ 'depth': [3, 8, 36, 3],
+ 'block_class': BottleneckBlock
+ }
+ }
+        stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
+ num_filters = [64, 128, 256, 512]
+ self.depth = supported_layers[layers]['depth']
+ self.F = []
+ # print(f"in_channels:{in_channels}")
+ self.conv = ConvBNLayer(in_channels=in_channels,
+ out_channels=64,
+ kernel=7,
+                                stride=2)  # 64x256 -> 32x128
+
+ self.block_list = nn.ModuleList()
+ in_ch = 64
+ if layers >= 50:
+ for block in range(len(self.depth)):
+ for i in range(self.depth[block]):
+ self.block_list.append(
+ BottleneckBlock(
+ in_channels=in_ch,
+ out_channels=num_filters[block],
+ stride=stride_list[block] if i == 0 else 1))
+ in_ch = num_filters[block] * 4
+
+ else:
+ for block in range(len(self.depth)):
+ for i in range(self.depth[block]):
+ if i == 0 and block != 0:
+ stride = (2, 1)
+ else:
+ stride = (1, 1)
+ basic_block = BasicBlock(
+ in_channels=in_ch,
+ out_channels=num_filters[block],
+ stride=stride_list[block] if i == 0 else 1,
+ is_first=block == i == 0)
+                    in_ch = basic_block.out_channels
+ self.block_list.append(basic_block)
+
+ out_ch_list = [in_ch // 4, in_ch // 2, in_ch]
+ self.base_block = nn.ModuleList()
+ self.conv_trans = []
+ self.bn_block = []
+ for i in [-2, -3]:
+ in_channels = out_ch_list[i + 1] + out_ch_list[i]
+ self.base_block.append(
+                nn.Conv2d(in_channels, out_ch_list[i], kernel_size=1))  # 1x1 conv to adjust the channel count
+ self.base_block.append(
+ nn.Conv2d(out_ch_list[i],
+ out_ch_list[i],
+ kernel_size=3,
+                          padding=1))  # 3x3 conv to fuse the concatenated features
+ self.base_block.append(
+ nn.Sequential(nn.BatchNorm2d(out_ch_list[i]), nn.ReLU(True)))
+ self.base_block.append(nn.Conv2d(out_ch_list[i], 512, kernel_size=1))
+
+ self.out_channels = 512
+
+ def forward(self, x):
+
+ # print(f"before resnetfpn x.shape:{x.shape}")
+ x = self.conv(x)
+ fpn_list = []
+ F = []
+ for i in range(len(self.depth)):
+ fpn_list.append(np.sum(self.depth[:i + 1]))
+ for i, block in enumerate(self.block_list):
+ x = block(x)
+
+ for number in fpn_list:
+ if i + 1 == number:
+ F.append(x)
+ base = F[-1]
+
+ j = 0
+ for i, block in enumerate(self.base_block):
+ if i % 3 == 0 and i < 6:
+ j = j + 1
+ b, c, w, h = F[-j - 1].size()
+                if [w, h] != list(base.size()[2:]):
+                    base = self.conv_trans[j - 1](base)
+                    base = self.bn_block[j - 1](base)
+ base = torch.cat([base, F[-j - 1]], dim=1)
+ base = block(base)
+
+ return base
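+
+
+if __name__ == '__main__':
+    # Minimal smoke-test sketch, assuming layers=50 and the 64x256 grayscale
+    # input hinted at by the comment on self.conv; shapes are illustrative.
+    model = ResNet_FPN(in_channels=1, layers=50)
+    out = model(torch.randn(1, 1, 64, 256))
+    # The FPN merges the last three stages and projects to 512 channels,
+    # so this should be a [1, 512, 8, 32] feature map.
+    print(out.shape)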
diff --git a/openrec/modeling/encoders/rec_resnet_vd.py b/openrec/modeling/encoders/rec_resnet_vd.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fb62060e26336bac54407becbf57c2582e02599
--- /dev/null
+++ b/openrec/modeling/encoders/rec_resnet_vd.py
@@ -0,0 +1,252 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+from openrec.modeling.common import Activation
+
+
+class ConvBNLayer(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None,
+ ):
+ super(ConvBNLayer, self).__init__()
+ self.act = act
+ self.is_vd_mode = is_vd_mode
+ self._pool2d_avg = nn.AvgPool2d(kernel_size=stride,
+ stride=stride,
+ padding=0,
+ ceil_mode=False)
+
+ self._conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=1 if is_vd_mode else stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ bias=False,
+ )
+
+ self._batch_norm = nn.BatchNorm2d(out_channels, )
+ if self.act is not None:
+ self._act = Activation(act_type=act, inplace=True)
+
+ def forward(self, inputs):
+ if self.is_vd_mode:
+ inputs = self._pool2d_avg(inputs)
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ if self.act is not None:
+ y = self._act(y)
+ return y
+
+
+class BottleneckBlock(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None,
+ ):
+ super(BottleneckBlock, self).__init__()
+ self.scale = 4
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ act='relu',
+ )
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ )
+ self.conv2 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels * self.scale,
+ kernel_size=1,
+ act=None,
+ )
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels * self.scale,
+ kernel_size=1,
+ stride=stride,
+ is_vd_mode=not if_first and stride[0] != 1,
+ )
+
+ self.shortcut = shortcut
+ self.out_channels = out_channels * self.scale
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+
+ conv1 = self.conv1(y)
+ conv2 = self.conv2(conv1)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = short + conv2
+ y = F.relu(y)
+ return y
+
+
+class BasicBlock(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None,
+ ):
+ super(BasicBlock, self).__init__()
+ self.stride = stride
+ self.scale = 1
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ )
+ self.conv1 = ConvBNLayer(in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ act=None)
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=stride,
+ is_vd_mode=not if_first and stride[0] != 1,
+ )
+
+ self.shortcut = shortcut
+ self.out_channels = out_channels * self.scale
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = short + conv1
+ y = F.relu(y)
+ return y
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, in_channels=3, layers=50, **kwargs):
+ super(ResNet, self).__init__()
+
+ self.layers = layers
+ supported_layers = [18, 34, 50, 101, 152, 200]
+ assert layers in supported_layers, 'supported layers are {} but input layer is {}'.format(
+ supported_layers, layers)
+
+ if layers == 18:
+ depth = [2, 2, 2, 2]
+ elif layers == 34 or layers == 50:
+ depth = [3, 4, 6, 3]
+ elif layers == 101:
+ depth = [3, 4, 23, 3]
+ elif layers == 152:
+ depth = [3, 8, 36, 3]
+ elif layers == 200:
+ depth = [3, 12, 48, 3]
+
+ if layers >= 50:
+ block_class = BottleneckBlock
+ else:
+ block_class = BasicBlock
+ num_filters = [64, 128, 256, 512]
+
+ self.conv1_1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=32,
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ )
+ self.conv1_2 = ConvBNLayer(in_channels=32,
+ out_channels=32,
+ kernel_size=3,
+ stride=1,
+ act='relu')
+ self.conv1_3 = ConvBNLayer(in_channels=32,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ act='relu')
+ self.pool2d_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ # self.block_list = list()
+ self.block_list = nn.Sequential()
+ in_channels = 64
+ for block in range(len(depth)):
+ shortcut = False
+ for i in range(depth[block]):
+ if layers in [101, 152, 200] and block == 2:
+ if i == 0:
+ conv_name = 'res' + str(block + 2) + 'a'
+ else:
+ conv_name = 'res' + str(block + 2) + 'b' + str(i)
+ else:
+ conv_name = 'res' + str(block + 2) + chr(97 + i)
+
+ if i == 0 and block != 0:
+ stride = (2, 1)
+ else:
+ stride = (1, 1)
+
+ block_instance = block_class(
+ in_channels=in_channels,
+ out_channels=num_filters[block],
+ stride=stride,
+ shortcut=shortcut,
+ if_first=block == i == 0,
+ name=conv_name,
+ )
+ shortcut = True
+ in_channels = block_instance.out_channels
+ # self.block_list.append(bottleneck_block)
+ self.block_list.add_module('bb_%d_%d' % (block, i),
+ block_instance)
+ self.out_channels = num_filters[block]
+ self.out_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+
+ def forward(self, inputs):
+ y = self.conv1_1(inputs)
+ y = self.conv1_2(y)
+ y = self.conv1_3(y)
+ y = self.pool2d_max(y)
+ for block in self.block_list:
+ y = block(y)
+ y = self.out_pool(y)
+
+ return y
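+
+
+if __name__ == '__main__':
+    # Minimal smoke-test sketch (assumes a 3x32x100 recognition input).
+    import torch
+
+    model = ResNet(in_channels=3, layers=50)
+    out = model(torch.randn(1, 3, 32, 100))
+    # out_pool reduces the height to 1 and halves the width; with bottleneck
+    # blocks the final feature has 4 * 512 = 2048 channels: [1, 2048, 1, 25].
+    print(out.shape)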
diff --git a/openrec/modeling/encoders/repvit.py b/openrec/modeling/encoders/repvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cff40ab33832ebee04b45d9b628255fb6274973
--- /dev/null
+++ b/openrec/modeling/encoders/repvit.py
@@ -0,0 +1,338 @@
+"""
+This code is adapted from:
+https://github.com/THU-MIG/RepViT
+"""
+
+import torch.nn as nn
+import torch
+from torch.nn.init import constant_
+
+
+def _make_divisible(v, divisor, min_value=None):
+ """
+ This function is taken from the original tf repo.
+    It ensures that all layers have a channel number that is divisible by 8.
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ :param v:
+ :param divisor:
+ :param min_value:
+ :return:
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
+ min_value = min_value or divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < round_limit * v:
+ new_v += divisor
+ return new_v
+
+
+class SEModule(nn.Module):
+ """SE Module as defined in original SE-Nets with a few additions
+ Additions include:
+ * divisor can be specified to keep channels % div == 0 (default: 8)
+ * reduction channels can be specified directly by arg (if rd_channels is set)
+ * reduction channels can be specified by float rd_ratio (default: 1/16)
+ * global max pooling can be added to the squeeze aggregation
+ * customizable activation, normalization, and gate layer
+ """
+
+ def __init__(
+ self,
+ channels,
+ rd_ratio=1.0 / 16,
+ rd_channels=None,
+ rd_divisor=8,
+ act_layer=nn.ReLU,
+ ):
+ super(SEModule, self).__init__()
+ if not rd_channels:
+ rd_channels = make_divisible(channels * rd_ratio,
+ rd_divisor,
+ round_limit=0.0)
+ self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
+ self.act = act_layer()
+ self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
+
+ def forward(self, x):
+ x_se = x.mean((2, 3), keepdim=True)
+ x_se = self.fc1(x_se)
+ x_se = self.act(x_se)
+ x_se = self.fc2(x_se)
+ return x * torch.sigmoid(x_se)
+
+
+class Conv2D_BN(nn.Sequential):
+
+ def __init__(
+ self,
+ a,
+ b,
+ ks=1,
+ stride=1,
+ pad=0,
+ dilation=1,
+ groups=1,
+ bn_weight_init=1,
+ resolution=-10000,
+ ):
+ super().__init__()
+ self.add_module(
+ 'c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups,
+ bias=False))
+ self.add_module('bn', nn.BatchNorm2d(b))
+ constant_(self.bn.weight, bn_weight_init)
+ constant_(self.bn.bias, 0)
+
+ @torch.no_grad()
+ def fuse(self):
+ c, bn = self._modules.values()
+ w = bn.weight / (bn.running_var + bn.eps)**0.5
+ w = c.weight * w[:, None, None, None]
+ b = bn.bias - bn.running_mean * bn.weight / \
+ (bn.running_var + bn.eps)**0.5
+ m = nn.Conv2d(w.size(1) * self.c.groups,
+ w.size(0),
+ w.shape[2:],
+ stride=self.c.stride,
+ padding=self.c.padding,
+ dilation=self.c.dilation,
+ groups=self.c.groups,
+ device=c.weight.device)
+ m.weight.data.copy_(w)
+ m.bias.data.copy_(b)
+ return m
+
+
+class Residual(torch.nn.Module):
+
+ def __init__(self, m, drop=0.):
+ super().__init__()
+ self.m = m
+ self.drop = drop
+
+ def forward(self, x):
+ if self.training and self.drop > 0:
+ return x + self.m(x) * torch.rand(
+ x.size(0), 1, 1, 1, device=x.device).ge_(
+ self.drop).div(1 - self.drop).detach()
+ else:
+ return x + self.m(x)
+
+ @torch.no_grad()
+ def fuse(self):
+ if isinstance(self.m, Conv2D_BN):
+ m = self.m.fuse()
+ assert (m.groups == m.in_channels)
+ identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
+ identity = nn.functional.pad(identity, [1, 1, 1, 1])
+ m.weight += identity.to(m.weight.device)
+ return m
+ elif isinstance(self.m, nn.Conv2d):
+ m = self.m
+ assert (m.groups != m.in_channels)
+ identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
+ identity = nn.functional.pad(identity, [1, 1, 1, 1])
+ m.weight += identity.to(m.weight.device)
+ return m
+ else:
+ return self
+
+
+class RepVGGDW(nn.Module):
+
+ def __init__(self, ed) -> None:
+ super().__init__()
+ self.conv = Conv2D_BN(ed, ed, 3, 1, 1, groups=ed)
+ self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
+ self.dim = ed
+ self.bn = nn.BatchNorm2d(ed)
+
+ def forward(self, x):
+ return self.bn((self.conv(x) + self.conv1(x)) + x)
+
+ @torch.no_grad()
+ def fuse(self):
+ conv = self.conv.fuse()
+ conv1 = self.conv1
+
+ conv_w = conv.weight
+ conv_b = conv.bias
+ conv1_w = conv1.weight
+ conv1_b = conv1.bias
+
+ conv1_w = nn.functional.pad(conv1_w, [1, 1, 1, 1])
+
+ identity = nn.functional.pad(
+ torch.ones(conv1_w.shape[0],
+ conv1_w.shape[1],
+ 1,
+ 1,
+ device=conv1_w.device), [1, 1, 1, 1])
+
+ final_conv_w = conv_w + conv1_w + identity
+ final_conv_b = conv_b + conv1_b
+
+ conv.weight.data.copy_(final_conv_w)
+ conv.bias.data.copy_(final_conv_b)
+
+ bn = self.bn
+ w = bn.weight / (bn.running_var + bn.eps)**0.5
+ w = conv.weight * w[:, None, None, None]
+ b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
+ (bn.running_var + bn.eps)**0.5
+ conv.weight.data.copy_(w)
+ conv.bias.data.copy_(b)
+ return conv
+
+
+class RepViTBlock(nn.Module):
+
+ def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se,
+ use_hs):
+ super(RepViTBlock, self).__init__()
+
+ self.identity = stride == 1 and inp == oup
+ assert hidden_dim == 2 * inp
+
+ if stride != 1:
+ self.token_mixer = nn.Sequential(
+ Conv2D_BN(inp,
+ inp,
+ kernel_size,
+ stride, (kernel_size - 1) // 2,
+ groups=inp),
+ SEModule(inp, 0.25) if use_se else nn.Identity(),
+ Conv2D_BN(inp, oup, ks=1, stride=1, pad=0),
+ )
+ self.channel_mixer = Residual(
+ nn.Sequential(
+ # pw
+ Conv2D_BN(oup, 2 * oup, 1, 1, 0),
+ nn.GELU() if use_hs else nn.GELU(),
+ # pw-linear
+ Conv2D_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
+ ))
+ else:
+ assert self.identity
+ self.token_mixer = nn.Sequential(
+ RepVGGDW(inp),
+ SEModule(inp, 0.25) if use_se else nn.Identity(),
+ )
+ self.channel_mixer = Residual(
+ nn.Sequential(
+ # pw
+ Conv2D_BN(inp, hidden_dim, 1, 1, 0),
+ nn.GELU() if use_hs else nn.GELU(),
+ # pw-linear
+ Conv2D_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
+ ))
+
+ def forward(self, x):
+ return self.channel_mixer(self.token_mixer(x))
+
+
+class RepViT(nn.Module):
+
+ def __init__(self, cfgs, in_channels=3, out_indices=None):
+ super(RepViT, self).__init__()
+ # setting of inverted residual blocks
+ self.cfgs = cfgs
+
+ # building first layer
+ input_channel = self.cfgs[0][2]
+ patch_embed = nn.Sequential(
+ Conv2D_BN(in_channels, input_channel // 2, 3, 2, 1),
+ nn.GELU(),
+ Conv2D_BN(input_channel // 2, input_channel, 3, 2, 1),
+ )
+ layers = [patch_embed]
+ # building inverted residual blocks
+ block = RepViTBlock
+ for k, t, c, use_se, use_hs, s in self.cfgs:
+ output_channel = _make_divisible(c, 8)
+ exp_size = _make_divisible(input_channel * t, 8)
+ layers.append(
+ block(input_channel, exp_size, output_channel, k, s, use_se,
+ use_hs))
+ input_channel = output_channel
+ self.features = nn.ModuleList(layers)
+ self.out_indices = out_indices
+ if out_indices is not None:
+ self.out_channels = [self.cfgs[ids - 1][2] for ids in out_indices]
+ else:
+ self.out_channels = self.cfgs[-1][2]
+
+ def forward(self, x):
+ if self.out_indices is not None:
+ return self.forward_det(x)
+ return self.forward_rec(x)
+
+ def forward_det(self, x):
+ outs = []
+ for i, f in enumerate(self.features):
+ x = f(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return outs
+
+ def forward_rec(self, x):
+ for f in self.features:
+ x = f(x)
+ return x
+
+
+def RepSVTREncoder(in_channels=3):
+ """
+    Constructs the RepViT-based RepSVTR recognition encoder.
+ """
+ # k, t, c, SE, HS, s
+ cfgs = [
+ [3, 2, 96, 1, 0, 1],
+ [3, 2, 96, 0, 0, 1],
+ [3, 2, 96, 0, 0, 1],
+ [3, 2, 192, 0, 1, (2, 1)],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 384, 0, 1, (2, 1)],
+ [3, 2, 384, 1, 1, 1],
+ [3, 2, 384, 0, 1, 1],
+ ]
+ return RepViT(cfgs, in_channels=in_channels)
+
+
+def RepSVTR_det(in_channels=3, out_indices=[2, 5, 10, 13]):
+ """
+ Constructs a MobileNetV3-Large model
+ """
+ # k, t, c, SE, HS, s
+ cfgs = [
+ [3, 2, 48, 1, 0, 1],
+ [3, 2, 48, 0, 0, 1],
+ [3, 2, 96, 0, 0, 2],
+ [3, 2, 96, 1, 0, 1],
+ [3, 2, 96, 0, 0, 1],
+ [3, 2, 192, 0, 1, 2],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 384, 0, 1, 2],
+ [3, 2, 384, 1, 1, 1],
+ [3, 2, 384, 0, 1, 1],
+ ]
+ return RepViT(cfgs, in_channels=in_channels, out_indices=out_indices)
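+
+
+if __name__ == '__main__':
+    # Minimal smoke-test sketch for the two entry points above; the input
+    # sizes are assumptions chosen to mimic typical rec/det settings.
+    rec = RepSVTREncoder(in_channels=3)
+    # Patch embedding downsamples by 4 and the two (2, 1) stages halve the
+    # height twice more, giving a [1, 384, 2, 25] feature map.
+    print(rec(torch.randn(1, 3, 32, 100)).shape)
+
+    det = RepSVTR_det(in_channels=3)
+    feats = det(torch.randn(1, 3, 160, 160))
+    print([f.shape for f in feats])  # four pyramid levels at strides 4/8/16/32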
diff --git a/openrec/modeling/encoders/resnet31_rnn.py b/openrec/modeling/encoders/resnet31_rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ef49bbbd61c4a6cf3d6592a8204451e06180109
--- /dev/null
+++ b/openrec/modeling/encoders/resnet31_rnn.py
@@ -0,0 +1,123 @@
+import torch.nn as nn
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding."""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution."""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False)
+
+
+class AsterBlock(nn.Module):
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(AsterBlock, self).__init__()
+ self.conv1 = conv1x1(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ out += residual
+ out = self.relu(out)
+ return out
+
+
+class ResNet_ASTER(nn.Module):
+ """For aster or crnn."""
+
+ def __init__(self, in_channels, with_lstm=True, n_group=1):
+ super(ResNet_ASTER, self).__init__()
+ self.with_lstm = with_lstm
+ self.n_group = n_group
+
+        # 512 either way: 2 * 256 from the bidirectional LSTM, or 512 raw CNN channels.
+        self.out_channels = 512
+
+ self.layer0 = nn.Sequential(
+ nn.Conv2d(in_channels,
+ 32,
+ kernel_size=(3, 3),
+ stride=1,
+ padding=1,
+ bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True))
+
+ self.inplanes = 32
+ self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50]
+ self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25]
+ self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25]
+ self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25]
+ self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25]
+
+ if with_lstm:
+ self.rnn = nn.LSTM(512,
+ 256,
+ bidirectional=True,
+ num_layers=2,
+ batch_first=True)
+ self.out_planes = 2 * 256
+ else:
+ self.out_planes = 512
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight,
+ mode='fan_out',
+ nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, planes, blocks, stride):
+ downsample = None
+ if stride != [1, 1] or self.inplanes != planes:
+ downsample = nn.Sequential(conv1x1(self.inplanes, planes, stride),
+ nn.BatchNorm2d(planes))
+
+ layers = []
+ layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes
+ for _ in range(1, blocks):
+ layers.append(AsterBlock(self.inplanes, planes))
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x0 = self.layer0(x)
+ x1 = self.layer1(x0)
+ x2 = self.layer2(x1)
+ x3 = self.layer3(x2)
+ x4 = self.layer4(x3)
+ x5 = self.layer5(x4)
+
+ cnn_feat = x5.squeeze(2) # [N, c, w]
+ cnn_feat = cnn_feat.transpose(2, 1).contiguous()
+ if self.with_lstm:
+ rnn_feat, _ = self.rnn(cnn_feat)
+ return rnn_feat
+ else:
+ return cnn_feat
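+
+
+if __name__ == '__main__':
+    # Minimal smoke-test sketch (assumes a 3x32x100 input; the five stages
+    # reduce the height 32 -> 1 and the width 100 -> 25).
+    import torch
+
+    model = ResNet_ASTER(in_channels=3, with_lstm=True)
+    seq = model(torch.randn(1, 3, 32, 100))
+    print(seq.shape)  # expected: [1, 25, 512] from the bidirectional LSTM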
diff --git a/openrec/modeling/encoders/svtrnet.py b/openrec/modeling/encoders/svtrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f63d4bba415c6a727b001e36ab24bbb77c1e667b
--- /dev/null
+++ b/openrec/modeling/encoders/svtrnet.py
@@ -0,0 +1,574 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_
+
+from openrec.modeling.common import DropPath, Identity, Mlp
+
+
+class ConvBNLayer(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=0,
+ bias=False,
+ groups=1,
+ act=nn.GELU,
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ bias=bias,
+ )
+ self.norm = nn.BatchNorm2d(out_channels)
+ self.act = act()
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.norm(out)
+ out = self.act(out)
+ return out
+
+
+class ConvMixer(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ HW=[8, 25],
+ local_k=[3, 3],
+ ):
+ super().__init__()
+ self.HW = HW
+ self.dim = dim
+ self.local_mixer = nn.Conv2d(dim,
+ dim,
+ local_k,
+ 1, [local_k[0] // 2, local_k[1] // 2],
+ groups=num_heads)
+
+ def forward(self, x):
+ h = self.HW[0]
+ w = self.HW[1]
+ x = x.transpose(1, 2).reshape([x.shape[0], self.dim, h, w])
+ x = self.local_mixer(x)
+ x = x.flatten(2).transpose(1, 2)
+ return x
+
+
+class Attention(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ mixer='Global',
+ HW=None,
+ local_k=[7, 11],
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.dim = dim
+ self.head_dim = dim // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.HW = HW
+ if HW is not None:
+ H = HW[0]
+ W = HW[1]
+ self.N = H * W
+ self.C = dim
+ if mixer == 'Local' and HW is not None:
+ hk = local_k[0]
+ wk = local_k[1]
+ mask = torch.ones(H * W,
+ H + hk - 1,
+ W + wk - 1,
+ dtype=torch.float32,
+ requires_grad=False)
+ for h in range(0, H):
+ for w in range(0, W):
+ mask[h * W + w, h:h + hk, w:w + wk] = 0.0
+ mask = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk // 2].flatten(1)
+ mask[mask >= 1] = -np.inf
+ self.register_buffer('mask', mask[None, None, :, :])
+ self.mixer = mixer
+
+ def forward(self, x):
+ B, N, _ = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+ self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0)
+
+ # x = F.scaled_dot_product_attention(
+ # q, k, v,
+ # attn_mask=mask,
+ # dropout_p=self.attn_drop.p
+ # )
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+ if self.mixer == 'Local':
+ attn += self.mask
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x = x.transpose(1, 2).reshape(B, N, self.dim)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mixer='Global',
+ local_mixer=[7, 11],
+ HW=None,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer='nn.LayerNorm',
+ eps=1e-6,
+ prenorm=True,
+ ):
+ super().__init__()
+ if isinstance(norm_layer, str):
+ self.norm1 = eval(norm_layer)(dim, eps=eps)
+ else:
+ self.norm1 = norm_layer(dim)
+ if mixer == 'Global' or mixer == 'Local':
+ self.mixer = Attention(
+ dim,
+ num_heads=num_heads,
+ mixer=mixer,
+ HW=HW,
+ local_k=local_mixer,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ elif mixer == 'Conv':
+ self.mixer = ConvMixer(dim,
+ num_heads=num_heads,
+ HW=HW,
+ local_k=local_mixer)
+ else:
+ raise TypeError('The mixer must be one of [Global, Local, Conv]')
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ if isinstance(norm_layer, str):
+ self.norm2 = eval(norm_layer)(dim, eps=eps)
+ else:
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp_ratio = mlp_ratio
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+ self.prenorm = prenorm
+
+ def forward(self, x):
+ if self.prenorm:
+ x = self.norm1(x + self.drop_path(self.mixer(x)))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ else:
+ x = x + self.drop_path(self.mixer(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding."""
+
+ def __init__(
+ self,
+ img_size=[32, 100],
+ in_channels=3,
+ embed_dim=768,
+ sub_num=2,
+ patch_size=[4, 4],
+ mode='pope',
+ ):
+ super().__init__()
+ num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] //
+ (2**sub_num))
+ self.img_size = img_size
+ self.num_patches = num_patches
+ self.embed_dim = embed_dim
+ self.norm = None
+ if mode == 'pope':
+ if sub_num == 2:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ )
+ if sub_num == 3:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 4,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ )
+ elif mode == 'linear':
+ self.proj = nn.Conv2d(1,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=patch_size)
+ self.num_patches = img_size[0] // patch_size[0] * img_size[
+ 1] // patch_size[1]
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ assert (
+ H == self.img_size[0] and W == self.img_size[1]
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class SubSample(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ types='Pool',
+ stride=[2, 1],
+ sub_norm='nn.LayerNorm',
+ act=None,
+ ):
+ super().__init__()
+ self.types = types
+ if types == 'Pool':
+ self.avgpool = nn.AvgPool2d(kernel_size=[3, 5],
+ stride=stride,
+ padding=[1, 2])
+ self.maxpool = nn.MaxPool2d(kernel_size=[3, 5],
+ stride=stride,
+ padding=[1, 2])
+ self.proj = nn.Linear(in_channels, out_channels)
+ else:
+ self.conv = nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1)
+ self.norm = eval(sub_norm)(out_channels)
+ if act is not None:
+ self.act = act()
+ else:
+ self.act = None
+
+ def forward(self, x):
+ if self.types == 'Pool':
+ x1 = self.avgpool(x)
+ x2 = self.maxpool(x)
+ x = (x1 + x2) * 0.5
+ out = self.proj(x.flatten(2).transpose(1, 2))
+ else:
+ x = self.conv(x)
+ out = x.flatten(2).transpose(1, 2)
+ out = self.norm(out)
+ if self.act is not None:
+ out = self.act(out)
+
+ return out
+
+
+class SVTRNet(nn.Module):
+
+ def __init__(
+ self,
+ img_size=[32, 100],
+ in_channels=3,
+ embed_dim=[64, 128, 256],
+ depth=[3, 6, 3],
+ num_heads=[2, 4, 8],
+ mixer=['Local'] * 6 +
+ ['Global'] * 6, # Local atten, Global atten, Conv
+ local_mixer=[[7, 11], [7, 11], [7, 11]],
+ patch_merging='Conv', # Conv, Pool, None
+ sub_k=[[2, 1], [2, 1]],
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ last_drop=0.1,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer='nn.LayerNorm',
+ sub_norm='nn.LayerNorm',
+ eps=1e-6,
+ out_channels=192,
+ out_char_num=25,
+ block_unit='Block',
+ act='nn.GELU',
+ last_stage=True,
+ sub_num=2,
+ prenorm=True,
+ use_lenhead=False,
+ feature2d=False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.img_size = img_size
+ self.embed_dim = embed_dim
+ self.out_channels = out_channels
+ self.prenorm = prenorm
+ self.feature2d = feature2d
+ patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ in_channels=in_channels,
+ embed_dim=embed_dim[0],
+ sub_num=sub_num,
+ )
+ num_patches = self.patch_embed.num_patches
+ self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
+ self.hw = [
+ [self.HW[0] // sub_k[0][0], self.HW[1] // sub_k[0][1]],
+ [
+ self.HW[0] // (sub_k[0][0] * sub_k[1][0]),
+ self.HW[1] // (sub_k[0][1] * sub_k[1][1])
+ ],
+ ]
+ # self.pos_embed = self.create_parameter(
+ # shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
+ # self.add_parameter("pos_embed", self.pos_embed)
+ self.pos_embed = nn.Parameter(
+ torch.zeros([1, num_patches, embed_dim[0]], dtype=torch.float32),
+ requires_grad=True,
+ )
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ Block_unit = eval(block_unit)
+
+ dpr = np.linspace(0, drop_path_rate, sum(depth))
+ self.blocks1 = nn.ModuleList([
+ Block_unit(
+ dim=embed_dim[0],
+ num_heads=num_heads[0],
+ mixer=mixer[0:depth[0]][i],
+ HW=self.HW,
+ local_mixer=local_mixer[0],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[0:depth[0]][i],
+ norm_layer=norm_layer,
+ eps=eps,
+ prenorm=prenorm,
+ ) for i in range(depth[0])
+ ])
+ if patch_merging is not None:
+ self.sub_sample1 = SubSample(
+ embed_dim[0],
+ embed_dim[1],
+ sub_norm=sub_norm,
+ stride=sub_k[0],
+ types=patch_merging,
+ )
+ HW = self.hw[0]
+ else:
+ HW = self.HW
+ self.patch_merging = patch_merging
+ self.blocks2 = nn.ModuleList([
+ Block_unit(
+ dim=embed_dim[1],
+ num_heads=num_heads[1],
+ mixer=mixer[depth[0]:depth[0] + depth[1]][i],
+ HW=HW,
+ local_mixer=local_mixer[1],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
+ norm_layer=norm_layer,
+ eps=eps,
+ prenorm=prenorm,
+ ) for i in range(depth[1])
+ ])
+ if patch_merging is not None:
+ self.sub_sample2 = SubSample(
+ embed_dim[1],
+ embed_dim[2],
+ sub_norm=sub_norm,
+ stride=sub_k[1],
+ types=patch_merging,
+ )
+ HW = self.hw[1]
+ self.blocks3 = nn.ModuleList([
+ Block_unit(
+ dim=embed_dim[2],
+ num_heads=num_heads[2],
+ mixer=mixer[depth[0] + depth[1]:][i],
+ HW=HW,
+ local_mixer=local_mixer[2],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0] + depth[1]:][i],
+ norm_layer=norm_layer,
+ eps=eps,
+ prenorm=prenorm,
+ ) for i in range(depth[2])
+ ])
+ self.last_stage = last_stage
+ if last_stage:
+ self.avg_pool = nn.AdaptiveAvgPool2d([1, out_char_num])
+ self.last_conv = nn.Conv2d(
+ in_channels=embed_dim[2],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.hardswish = nn.Hardswish()
+ self.dropout = nn.Dropout(p=last_drop)
+ else:
+ self.out_channels = embed_dim[2]
+ if not prenorm:
+ self.norm = eval(norm_layer)(embed_dim[-1], eps=eps)
+ self.use_lenhead = use_lenhead
+ if use_lenhead:
+ self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
+ self.hardswish_len = nn.Hardswish()
+ self.dropout_len = nn.Dropout(p=last_drop)
+
+ trunc_normal_(self.pos_embed, mean=0, std=0.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, mean=0, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ if isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+ if isinstance(m, nn.Conv2d):
+ kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'sub_sample1', 'sub_sample2', 'sub_sample3'}
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+ for blk in self.blocks1:
+ x = blk(x)
+ if self.patch_merging is not None:
+ x = self.sub_sample1(
+ x.transpose(1, 2).reshape(-1, self.embed_dim[0], self.HW[0],
+ self.HW[1]))
+ for blk in self.blocks2:
+ x = blk(x)
+ if self.patch_merging is not None:
+ x = self.sub_sample2(
+ x.transpose(1, 2).reshape(-1, self.embed_dim[1], self.hw[0][0],
+ self.hw[0][1]))
+ for blk in self.blocks3:
+ x = blk(x)
+ if not self.prenorm:
+ x = self.norm(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ if self.feature2d:
+ x = x.transpose(1, 2).reshape(-1, self.embed_dim[2], self.hw[1][0],
+ self.hw[1][1])
+ if self.use_lenhead:
+ len_x = self.len_conv(x.mean(1))
+ len_x = self.dropout_len(self.hardswish_len(len_x))
+ if self.last_stage:
+ x = self.avg_pool(
+ x.transpose(1, 2).reshape(-1, self.embed_dim[2], self.hw[1][0],
+ self.hw[1][1]))
+ x = self.last_conv(x)
+ x = self.hardswish(x)
+ x = self.dropout(x)
+ x = x.flatten(2).transpose(1, 2)
+ if self.use_lenhead:
+ return x, len_x
+ return x
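+
+
+if __name__ == '__main__':
+    # Minimal smoke-test sketch with the default configuration (illustrative;
+    # img_size must match the input resolution because of the PatchEmbed assert).
+    model = SVTRNet(img_size=[32, 100], in_channels=3)
+    out = model(torch.randn(1, 3, 32, 100))
+    # With last_stage=True the features are pooled to out_char_num tokens of
+    # out_channels dims, i.e. [1, 25, 192].
+    print(out.shape)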
diff --git a/openrec/modeling/encoders/svtrnet2dpos.py b/openrec/modeling/encoders/svtrnet2dpos.py
new file mode 100644
index 0000000000000000000000000000000000000000..b027d9d0c9e657cca5b079f3c5742166bfa46319
--- /dev/null
+++ b/openrec/modeling/encoders/svtrnet2dpos.py
@@ -0,0 +1,616 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_
+
+from openrec.modeling.common import DropPath, Identity, Mlp
+
+
+class ConvBNLayer(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=0,
+ bias=False,
+ groups=1,
+ act=nn.GELU,
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ bias=bias,
+ )
+ self.norm = nn.BatchNorm2d(out_channels)
+ self.act = act()
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.norm(out)
+ out = self.act(out)
+ return out
+
+
+class ConvMixer(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ HW=[8, 25],
+ local_k=[3, 3],
+ ):
+ super().__init__()
+ self.HW = HW
+ self.dim = dim
+ self.local_mixer = nn.Conv2d(dim,
+ dim,
+ local_k,
+ 1, [local_k[0] // 2, local_k[1] // 2],
+ groups=num_heads)
+
+ def forward(self, x, w):
+ x = x.transpose(1, 2).reshape([x.shape[0], self.dim, -1, w])
+ x = self.local_mixer(x)
+ x = x.flatten(2).transpose(1, 2)
+ return x
+
+
+class ConvMlp(nn.Module):
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ groups=1,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1, groups=groups)
+ self.act = act_layer()
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class ConvBlock(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mixer='Global',
+ local_mixer=[7, 11],
+ HW=None,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer='nn.LayerNorm',
+ eps=1e-6,
+ prenorm=True,
+ ):
+ super().__init__()
+
+ self.norm1 = nn.BatchNorm2d(dim)
+ self.local_mixer = nn.Conv2d(dim,
+ dim, [5, 5],
+ 1, [2, 2],
+ groups=num_heads)
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ self.norm2 = nn.BatchNorm2d(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ConvMlp(in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+ self.prenorm = prenorm
+
+ def forward(self, x):
+ x = self.norm1(x + self.drop_path(self.local_mixer(x)))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ return x
+
+
+class Attention(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ mixer='Global',
+ HW=None,
+ local_k=[7, 11],
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.dim = dim
+ self.head_dim = dim // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.HW = HW
+ if HW is not None:
+ H = HW[0]
+ W = HW[1]
+ if W == -1:
+ W = 300
+ self.C = dim
+ self.H = H
+ self.W = W
+ if mixer == 'Local' and HW is not None:
+ if HW[1] == -1:
+ wk = 29
+ else:
+ wk = local_k[1]
+ self.wk = wk
+ mask = torch.ones(W, W, dtype=torch.float32, requires_grad=False)
+
+ for w in range(0, W):
+ b_w = w - wk // 2 if w - wk // 2 > 0 else 0
+ if b_w > W - wk:
+ b_w = W - wk
+ mask[w, b_w:b_w + wk] = 0.0
+ mask[mask >= 1] = -np.inf
+
+ self.register_buffer('mask', mask)
+ self.mixer = mixer
+
+ def forward(self, x, w):
+ B, N, _ = x.shape
+ h = N // w
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+ self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0)
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+ if self.mixer == 'Local' and w >= 32:
+ mask1 = self.mask[(self.W - w) // 2:-(self.W - w) // 2,
+ (self.W - w) // 2:-(self.W - w) // 2]
+ mask1[:(self.wk // 2 + 1)] = self.mask[:(self.wk // 2 + 1), :w]
+ mask1[-(self.wk // 2 + 1):] = self.mask[-(self.wk // 2 + 1):, -w:]
+ attn += mask1[None, None, :, :].tile(B, 1, h, h)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x = x.transpose(1, 2).reshape(B, N, self.dim)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mixer='Global',
+ local_mixer=[7, 11],
+ HW=None,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer='nn.LayerNorm',
+ eps=1e-6,
+ ):
+ super().__init__()
+ if isinstance(norm_layer, str):
+ self.norm1 = eval(norm_layer)(dim, eps=eps)
+ else:
+ self.norm1 = norm_layer(dim)
+ if mixer == 'Global' or mixer == 'Local':
+ self.mixer = Attention(
+ dim,
+ num_heads=num_heads,
+ mixer=mixer,
+ HW=HW,
+ local_k=local_mixer,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ elif mixer == 'Conv':
+ self.mixer = ConvMixer(dim,
+ num_heads=num_heads,
+ HW=HW,
+ local_k=local_mixer)
+ else:
+ raise TypeError('The mixer must be one of [Global, Local, Conv]')
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ if isinstance(norm_layer, str):
+ self.norm2 = eval(norm_layer)(dim, eps=eps)
+ else:
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp_ratio = mlp_ratio
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, x, w):
+ x = self.norm1(x + self.drop_path(self.mixer(x, w)))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ return x, w
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding."""
+
+ def __init__(
+ self,
+ img_size=[32, 100],
+ in_channels=3,
+ embed_dim=768,
+ sub_num=2,
+ patch_size=[4, 4],
+ mode='pope',
+ ):
+ super().__init__()
+ num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] //
+ (2**sub_num))
+ self.img_size = img_size
+ self.num_patches = num_patches
+ self.embed_dim = embed_dim
+ self.norm = None
+ if mode == 'pope':
+ if sub_num == 2:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ )
+ if sub_num == 3:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 4,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ )
+ elif mode == 'linear':
+ self.proj = nn.Conv2d(1,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=patch_size)
+ self.num_patches = img_size[0] // patch_size[0] * img_size[
+ 1] // patch_size[1]
+
+ def forward(self, x):
+ x = self.proj(x)
+ return x
+
+
+class SubSample(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ types='Pool',
+ stride=[2, 1],
+ sub_norm='nn.LayerNorm',
+ act=None,
+ ):
+ super().__init__()
+ self.types = types
+ if types == 'Pool':
+ self.avgpool = nn.AvgPool2d(kernel_size=[3, 5],
+ stride=stride,
+ padding=[1, 2])
+ self.maxpool = nn.MaxPool2d(kernel_size=[3, 5],
+ stride=stride,
+ padding=[1, 2])
+ self.proj = nn.Linear(in_channels, out_channels)
+ else:
+ self.conv = nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1)
+ self.dim = in_channels
+ self.norm = eval(sub_norm)(out_channels)
+ if act is not None:
+ self.act = act()
+ else:
+ self.act = None
+
+ def forward(self, x, w):
+ if self.types == 'Pool':
+ x1 = self.avgpool(x)
+ x2 = self.maxpool(x)
+ x = (x1 + x2) * 0.5
+ out = self.proj(x.flatten(2).transpose(1, 2))
+ else:
+ x = x.transpose(1, 2).reshape([x.shape[0], self.dim, -1, w])
+ x = self.conv(x)
+ out = x.flatten(2).transpose(1, 2)
+ out = self.norm(out)
+ if self.act is not None:
+ out = self.act(out)
+
+ return out, w
+
+
+class FlattenTranspose(nn.Module):
+
+ def forward(self, x):
+ return x.flatten(2).transpose(1, 2)
+
+
+class DownSConv(nn.Module):
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels,
+ out_channels,
+ 3,
+ stride=[2, 1],
+ padding=1)
+ self.norm = nn.LayerNorm(out_channels)
+
+ def forward(self, x, w):
+ B, N, C = x.shape
+ x = x.transpose(1, 2).reshape(B, C, -1, w)
+ x = self.conv(x)
+ w = x.shape[-1]
+ x = self.norm(x.flatten(2).transpose(1, 2))
+ return x, w
+
+
+class SVTRNet2DPos(nn.Module):
+
+ def __init__(
+ self,
+ img_size=[32, -1],
+ in_channels=3,
+ embed_dim=[64, 128, 256],
+ depth=[3, 6, 3],
+ num_heads=[2, 4, 8],
+ mixer=['Local'] * 6 +
+ ['Global'] * 6, # Local atten, Global atten, Conv
+ local_mixer=[[7, 11], [7, 11], [7, 11]],
+ patch_merging='Conv', # Conv, Pool, None
+ pool_size=[2, 1],
+ max_size=[16, 32],
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ last_drop=0.1,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer='nn.LayerNorm',
+ eps=1e-6,
+ act='nn.GELU',
+ last_stage=True,
+ sub_num=2,
+ use_first_sub=True,
+ flatten=False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.img_size = img_size
+ self.embed_dim = embed_dim
+ self.flatten = flatten
+ patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ in_channels=in_channels,
+ embed_dim=embed_dim[0],
+ sub_num=sub_num,
+ )
+ if img_size[1] == -1:
+ self.HW = [img_size[0] // (2**sub_num), -1]
+ else:
+ self.HW = [
+ img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)
+ ]
+ pos_embed = torch.zeros([1, max_size[0] * max_size[1], embed_dim[0]],
+ dtype=torch.float32)
+ trunc_normal_(pos_embed, mean=0, std=0.02)
+ self.pos_embed = nn.Parameter(
+ pos_embed.transpose(1, 2).reshape(1, embed_dim[0], max_size[0],
+ max_size[1]),
+ requires_grad=True,
+ )
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ conv_block_num = sum(
+ [1 if mixer_type == 'ConvB' else 0 for mixer_type in mixer])
+ Block_unit = [ConvBlock for _ in range(conv_block_num)
+ ] + [Block for _ in range(len(mixer) - conv_block_num)]
+ HW = self.HW
+ dpr = np.linspace(0, drop_path_rate, sum(depth))
+ self.conv_blocks1 = nn.ModuleList([
+ Block_unit[0:depth[0]][i](
+ dim=embed_dim[0],
+ num_heads=num_heads[0],
+ mixer=mixer[0:depth[0]][i],
+ HW=self.HW,
+ local_mixer=local_mixer[0],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[0:depth[0]][i],
+ norm_layer=norm_layer,
+ eps=eps,
+ ) for i in range(depth[0])
+ ])
+ if patch_merging is not None:
+ if use_first_sub:
+ stride = [2, 1]
+ HW = [self.HW[0] // 2, self.HW[1]]
+ else:
+ stride = [1, 1]
+ HW = self.HW
+
+ sub_sample1 = nn.Sequential(
+ nn.Conv2d(embed_dim[0],
+ embed_dim[1],
+ 3,
+ stride=stride,
+ padding=1),
+ nn.BatchNorm2d(embed_dim[1]),
+ )
+ self.conv_blocks1.append(sub_sample1)
+
+ self.patch_merging = patch_merging
+ self.trans_blocks = nn.ModuleList()
+ for i in range(depth[1]):
+ block = Block_unit[depth[0]:depth[0] + depth[1]][i](
+ dim=embed_dim[1],
+ num_heads=num_heads[1],
+ mixer=mixer[depth[0]:depth[0] + depth[1]][i],
+ HW=HW,
+ local_mixer=local_mixer[1],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
+ norm_layer=norm_layer,
+ eps=eps,
+ )
+ if i + depth[0] < conv_block_num:
+ self.conv_blocks1.append(block)
+ else:
+ self.trans_blocks.append(block)
+ if patch_merging is not None:
+ self.trans_blocks.append(DownSConv(embed_dim[1], embed_dim[2]))
+ HW = [HW[0] // 2, -1]
+
+ for i in range(depth[2]):
+ self.trans_blocks.append(Block_unit[depth[0] + depth[1]:][i](
+ dim=embed_dim[2],
+ num_heads=num_heads[2],
+ mixer=mixer[depth[0] + depth[1]:][i],
+ HW=HW,
+ local_mixer=local_mixer[2],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0] + depth[1]:][i],
+ norm_layer=norm_layer,
+ eps=eps,
+ ))
+ self.last_stage = last_stage
+ self.out_channels = embed_dim[-1]
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, mean=0, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ if isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+ if isinstance(m, nn.Conv2d):
+ kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'sub_sample1', 'sub_sample2'}
+
+ def forward(self, x):
+ x = self.patch_embed(x)
+
+ w = x.shape[-1]
+ x = x + self.pos_embed[:, :, :x.shape[-2], :w]
+
+ for blk in self.conv_blocks1:
+ x = blk(x)
+
+ x = x.flatten(2).transpose(1, 2)
+ for blk in self.trans_blocks:
+ x, w = blk(x, w)
+ B, N, C = x.shape
+ if not self.flatten:
+ x = x.transpose(1, 2).reshape(B, C, -1, w)
+
+ return x
diff --git a/openrec/modeling/encoders/svtrv2.py b/openrec/modeling/encoders/svtrv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..abd00cc6ffaed4058fa0c6bd0fee920a75eb3e65
--- /dev/null
+++ b/openrec/modeling/encoders/svtrv2.py
@@ -0,0 +1,470 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_
+
+from openrec.modeling.common import DropPath, Identity, Mlp
+
+
+class ConvBNLayer(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=0,
+ bias=False,
+ groups=1,
+ act=nn.GELU,
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ bias=bias,
+ )
+ self.norm = nn.BatchNorm2d(out_channels)
+ self.act = act()
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.norm(out)
+ out = self.act(out)
+ return out
+
+
+class ConvMixer(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ local_k=[5, 5],
+ ):
+ super().__init__()
+ self.local_mixer = nn.Conv2d(dim, dim, 5, 1, 2, groups=num_heads)
+
+ def forward(self, x, mask=None):
+ x = self.local_mixer(x)
+ return x
+
+
+class ConvMlp(nn.Module):
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ groups=1,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1, groups=groups)
+ self.act = act_layer()
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.dim = dim
+ self.head_dim = dim // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, mask=None):
+ B, N, _ = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+ self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0)
+ attn = q @ k.transpose(-2, -1) * self.scale
+ if mask is not None:
+ attn += mask.unsqueeze(0)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+ x = x.transpose(1, 2).reshape(B, N, self.dim)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mixer='Global',
+ local_k=[7, 11],
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ eps=1e-6,
+ ):
+ super().__init__()
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ if mixer == 'Global' or mixer == 'Local':
+ self.norm1 = norm_layer(dim, eps=eps)
+ self.mixer = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.norm2 = norm_layer(dim, eps=eps)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+ elif mixer == 'Conv':
+ self.norm1 = nn.BatchNorm2d(dim)
+ self.mixer = ConvMixer(dim, num_heads=num_heads, local_k=local_k)
+ self.norm2 = nn.BatchNorm2d(dim)
+ self.mlp = ConvMlp(in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+ else:
+ raise TypeError('The mixer must be one of [Global, Local, Conv]')
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+
+ def forward(self, x, mask=None):
+ x = self.norm1(x + self.drop_path(self.mixer(x, mask=mask)))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ return x
+
+
+class FlattenTranspose(nn.Module):
+
+ def forward(self, x, mask=None):
+ return x.flatten(2).transpose(1, 2)
+
+
+class SVTRStage(nn.Module):
+
+ def __init__(self,
+ feat_maxSize=[16, 128],
+ dim=64,
+ out_dim=256,
+ depth=3,
+ mixer=['Local'] * 3,
+ local_k=[7, 11],
+ sub_k=[2, 1],
+ num_heads=2,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path=[0.1] * 3,
+ norm_layer=nn.LayerNorm,
+ act=nn.GELU,
+ eps=1e-6,
+ downsample=None,
+ **kwargs):
+ super().__init__()
+ self.dim = dim
+
+ conv_block_num = sum([1 if mix == 'Conv' else 0 for mix in mixer])
+ if conv_block_num == depth:
+ self.mask = None
+ conv_block_num = 0
+ if downsample:
+ self.sub_norm = nn.BatchNorm2d(out_dim, eps=eps)
+ else:
+ if 'Local' in mixer:
+ mask = self.get_max2d_mask(feat_maxSize[0], feat_maxSize[1],
+ local_k)
+ self.register_buffer('mask', mask)
+ else:
+ self.mask = None
+ if downsample:
+ self.sub_norm = norm_layer(out_dim, eps=eps)
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ self.blocks.append(
+ Block(
+ dim=dim,
+ num_heads=num_heads,
+ mixer=mixer[i],
+ local_k=local_k,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=act,
+ attn_drop=attn_drop_rate,
+ drop_path=drop_path[i],
+ norm_layer=norm_layer,
+ eps=eps,
+ ))
+ if i == conv_block_num - 1:
+ self.blocks.append(FlattenTranspose())
+
+ if downsample:
+ self.downsample = nn.Conv2d(dim,
+ out_dim,
+ kernel_size=3,
+ stride=sub_k,
+ padding=1)
+ else:
+ self.downsample = None
+
+ def get_max2d_mask(self, H, W, local_k):
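+ # Build, for the maximum feature size, an additive attention mask in which
+ # each of the H*W query positions may only attend to an hk x wk window
+ # around itself; positions outside the window are filled with -inf.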
+ hk, wk = local_k
+ mask = torch.ones(H * W,
+ H + hk - 1,
+ W + wk - 1,
+ dtype=torch.float32,
+ requires_grad=False)
+ for h in range(0, H):
+ for w in range(0, W):
+ mask[h * W + w, h:h + hk, w:w + wk] = 0.0
+ mask = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk // 2] # .flatten(1)
+ mask[mask >= 1] = -np.inf
+ return mask.reshape(H, W, H, W)
+
+ def get_2d_mask(self, H1, W1):
+
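+ # Crop the cached max-size mask to the actual feature size (H1, W1) by
+ # stitching together its four corner regions.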
+ if H1 == self.mask.shape[0] and W1 == self.mask.shape[1]:
+ return self.mask.flatten(0, 1).flatten(1, 2).unsqueeze(0)
+ h_slice = H1 // 2
+ offset_h = H1 - 2 * h_slice
+ w_slice = W1 // 2
+ offset_w = W1 - 2 * w_slice
+ mask1 = self.mask[:h_slice + offset_h, :w_slice, :H1, :W1]
+ mask2 = self.mask[:h_slice + offset_h, -w_slice:, :H1, -W1:]
+ mask3 = self.mask[-h_slice:, :(w_slice + offset_w), -H1:, :W1]
+ mask4 = self.mask[-h_slice:, -(w_slice + offset_w):, -H1:, -W1:]
+
+ mask_top = torch.concat([mask1, mask2], 1)
+ mask_bott = torch.concat([mask3, mask4], 1)
+ mask = torch.concat([mask_top.flatten(2), mask_bott.flatten(2)], 0)
+ return mask.flatten(0, 1).unsqueeze(0)
+
+ def forward(self, x, sz=None):
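+ # sz holds the current (H, W) of the feature map; the local-attention mask
+ # is only built when the stage contains 'Local' mixers.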
+ if self.mask is not None:
+ mask = self.get_2d_mask(sz[0], sz[1])
+ else:
+ mask = self.mask
+ for blk in self.blocks:
+ x = blk(x, mask=mask)
+
+ if self.downsample is not None:
+ if x.dim() == 3:
+ x = x.transpose(1, 2).reshape(-1, self.dim, sz[0], sz[1])
+ x = self.downsample(x)
+ sz = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2)
+ else:
+ x = self.downsample(x)
+ sz = x.shape[2:]
+ x = self.sub_norm(x)
+ return x, sz
+
+
+class POPatchEmbed(nn.Module):
+ """Image to Patch Embedding."""
+
+ def __init__(self,
+ in_channels=3,
+ feat_max_size=[8, 32],
+ embed_dim=768,
+ use_pos_embed=False,
+ flatten=False):
+ super().__init__()
+ self.patch_embed = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ )
+ self.use_pos_embed = use_pos_embed
+ self.flatten = flatten
+ if use_pos_embed:
+ pos_embed = torch.zeros(
+ [1, feat_max_size[0] * feat_max_size[1], embed_dim],
+ dtype=torch.float32)
+ trunc_normal_(pos_embed, mean=0, std=0.02)
+ self.pos_embed = nn.Parameter(
+ pos_embed.transpose(1,
+ 2).reshape(1, embed_dim, feat_max_size[0],
+ feat_max_size[1]),
+ requires_grad=True,
+ )
+
+ def forward(self, x):
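+ # Two stride-2 convs give a 4x downsampled feature map; the positional
+ # embedding (if used) is cropped to the actual spatial size before adding.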
+ x = self.patch_embed(x)
+ sz = x.shape[2:]
+ if self.use_pos_embed:
+ x = x + self.pos_embed[:, :, :sz[0], :sz[1]]
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2)
+ return x, sz
+
+
+class SVTRv2(nn.Module):
+
+ def __init__(self,
+ max_sz=[32, 128],
+ in_channels=3,
+ out_channels=192,
+ depths=[3, 6, 3],
+ dims=[64, 128, 256],
+ mixer=[['Local'] * 3, ['Local'] * 3 + ['Global'] * 3,
+ ['Global'] * 3],
+ use_pos_embed=True,
+ local_k=[[7, 11], [7, 11], [-1, -1]],
+ sub_k=[[1, 1], [2, 1], [1, 1]],
+ num_heads=[2, 4, 8],
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ last_drop=0.1,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ act=nn.GELU,
+ last_stage=False,
+ eps=1e-6,
+ **kwargs):
+ super().__init__()
+ num_stages = len(depths)
+ self.num_features = dims[-1]
+
+ feat_max_size = [max_sz[0] // 4, max_sz[1] // 4]
+ self.pope = POPatchEmbed(in_channels=in_channels,
+ feat_max_size=feat_max_size,
+ embed_dim=dims[0],
+ use_pos_embed=use_pos_embed,
+ flatten=mixer[0][0] != 'Conv')
+
+ dpr = np.linspace(0, drop_path_rate,
+ sum(depths)) # stochastic depth decay rule
+
+ self.stages = nn.ModuleList()
+ for i_stage in range(num_stages):
+ stage = SVTRStage(
+ feat_maxSize=feat_max_size,
+ dim=dims[i_stage],
+ out_dim=dims[i_stage + 1] if i_stage < num_stages - 1 else 0,
+ depth=depths[i_stage],
+ mixer=mixer[i_stage],
+ local_k=local_k[i_stage],
+ sub_k=sub_k[i_stage],
+ num_heads=num_heads[i_stage],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_stage]):sum(depths[:i_stage + 1])],
+ norm_layer=norm_layer,
+ act=act,
+ downsample=False if i_stage == num_stages - 1 else True,
+ eps=eps,
+ )
+ self.stages.append(stage)
+ feat_max_size = [
+ feat_max_size[0] // sub_k[i_stage][0],
+ feat_max_size[1] // sub_k[i_stage][1]
+ ]
+
+ self.out_channels = self.num_features
+ self.last_stage = last_stage
+ if last_stage:
+ self.out_channels = out_channels
+ self.last_conv = nn.Linear(self.num_features,
+ self.out_channels,
+ bias=False)
+ self.hardswish = nn.Hardswish()
+ self.dropout = nn.Dropout(p=last_drop)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m: nn.Module):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, mean=0, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ if isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+ if isinstance(m, nn.Conv2d):
+ kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'patch_embed', 'downsample', 'pos_embed'}
+
+ def forward(self, x):
+ x, sz = self.pope(x)
+
+ for stage in self.stages:
+ x, sz = stage(x, sz)
+
+ if self.last_stage:
+ x = x.reshape(-1, sz[0], sz[1], self.num_features)
+ x = x.mean(1)
+ x = self.last_conv(x)
+ x = self.hardswish(x)
+ x = self.dropout(x)
+
+ return x
diff --git a/openrec/modeling/encoders/svtrv2_lnconv.py b/openrec/modeling/encoders/svtrv2_lnconv.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb1528dabeba6119bed05af07533c36abb8ba98
--- /dev/null
+++ b/openrec/modeling/encoders/svtrv2_lnconv.py
@@ -0,0 +1,503 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_
+
+from openrec.modeling.common import DropPath, Identity, Mlp
+
+
+class ConvBNLayer(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=0,
+ bias=False,
+ groups=1,
+ act=nn.GELU,
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ bias=bias,
+ )
+ self.norm = nn.BatchNorm2d(out_channels)
+ self.act = act()
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.norm(out)
+ out = self.act(out)
+ return out
+
+
+class Attention(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.dim = dim
+ self.head_dim = dim // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, _ = x.shape
+ qkv = (self.qkv(x).reshape(B, N, 3, self.num_heads,
+ self.head_dim).permute(2, 0, 3, 1, 4))
+ q, k, v = qkv.unbind(0)
+ attn = q @ k.transpose(-2, -1) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+ x = x.transpose(1, 2).reshape(B, N, self.dim)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ eps=1e-6,
+ ):
+ super().__init__()
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.norm1 = norm_layer(dim, eps=eps)
+ self.mixer = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ self.norm2 = norm_layer(dim, eps=eps)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, x):
+ x = self.norm1(x + self.drop_path(self.mixer(x)))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ return x
+
+
+class ConvBlock(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ eps=1e-6,
+ ):
+ super().__init__()
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.norm1 = norm_layer(dim, eps=eps)
+ self.mixer = nn.Conv2d(dim, dim, 5, 1, 2, groups=num_heads)
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ self.norm2 = norm_layer(dim, eps=eps)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, x):
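+ # Grouped 5x5 conv mixer with a residual connection; LayerNorm and the MLP
+ # run on the flattened token sequence, which is then reshaped back to 2D.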
+ C, H, W = x.shape[1:]
+ x = x + self.drop_path(self.mixer(x))
+ x = self.norm1(x.flatten(2).transpose(1, 2))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ x = x.transpose(1, 2).reshape(-1, C, H, W)
+ return x
+
+
+class FlattenTranspose(nn.Module):
+
+ def forward(self, x):
+ return x.flatten(2).transpose(1, 2)
+
+
+class SubSample2D(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride=[2, 1],
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1)
+ self.norm = nn.LayerNorm(out_channels)
+
+ def forward(self, x, sz):
+ x = self.conv(x)
+ C, H, W = x.shape[1:]
+ x = self.norm(x.flatten(2).transpose(1, 2))
+ x = x.transpose(1, 2).reshape(-1, C, H, W)
+ return x, [H, W]
+
+
+class SubSample1D(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride=[2, 1],
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1)
+ self.norm = nn.LayerNorm(out_channels)
+
+ def forward(self, x, sz):
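+ # Token sequence -> 2D map -> strided conv downsample -> LayerNorm -> tokens.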
+ C = x.shape[-1]
+ x = x.transpose(1, 2).reshape(-1, C, sz[0], sz[1])
+ x = self.conv(x)
+ C, H, W = x.shape[1:]
+ x = self.norm(x.flatten(2).transpose(1, 2))
+ return x, [H, W]
+
+
+class IdentitySize(nn.Module):
+
+ def forward(self, x, sz):
+ return x, sz
+
+
+class SVTRStage(nn.Module):
+
+ def __init__(self,
+ dim=64,
+ out_dim=256,
+ depth=3,
+ mixer=['Local'] * 3,
+ sub_k=[2, 1],
+ num_heads=2,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path=[0.1] * 3,
+ norm_layer=nn.LayerNorm,
+ act=nn.GELU,
+ eps=1e-6,
+ downsample=None,
+ **kwargs):
+ super().__init__()
+ self.dim = dim
+
+ conv_block_num = sum([1 if mix == 'Conv' else 0 for mix in mixer])
+ self.blocks = nn.Sequential()
+ for i in range(depth):
+ if mixer[i] == 'Conv':
+ self.blocks.append(
+ ConvBlock(
+ dim=dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ drop=drop_rate,
+ act_layer=act,
+ drop_path=drop_path[i],
+ norm_layer=norm_layer,
+ eps=eps,
+ ))
+ else:
+ self.blocks.append(
+ Block(
+ dim=dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=act,
+ attn_drop=attn_drop_rate,
+ drop_path=drop_path[i],
+ norm_layer=norm_layer,
+ eps=eps,
+ ))
+ if i == conv_block_num - 1 and mixer[-1] != 'Conv':
+ self.blocks.append(FlattenTranspose())
+ if downsample:
+ if mixer[-1] == 'Conv':
+ self.downsample = SubSample2D(dim, out_dim, stride=sub_k)
+ elif mixer[-1] == 'Global':
+ self.downsample = SubSample1D(dim, out_dim, stride=sub_k)
+ else:
+ self.downsample = IdentitySize()
+
+ def forward(self, x, sz):
+ for blk in self.blocks:
+ x = blk(x)
+ x, sz = self.downsample(x, sz)
+ return x, sz
+
+
+class ADDPosEmbed(nn.Module):
+
+ def __init__(self, feat_max_size=[8, 32], embed_dim=768):
+ super().__init__()
+ pos_embed = torch.zeros(
+ [1, feat_max_size[0] * feat_max_size[1], embed_dim],
+ dtype=torch.float32)
+ trunc_normal_(pos_embed, mean=0, std=0.02)
+ self.pos_embed = nn.Parameter(
+ pos_embed.transpose(1, 2).reshape(1, embed_dim, feat_max_size[0],
+ feat_max_size[1]),
+ requires_grad=True,
+ )
+
+ def forward(self, x):
+ sz = x.shape[2:]
+ x = x + self.pos_embed[:, :, :sz[0], :sz[1]]
+ return x
+
+
+class POPatchEmbed(nn.Module):
+ """Image to Patch Embedding."""
+
+ def __init__(
+ self,
+ in_channels=3,
+ feat_max_size=[8, 32],
+ embed_dim=768,
+ use_pos_embed=False,
+ flatten=False,
+ ):
+ super().__init__()
+ self.patch_embed = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=None,
+ ),
+ )
+ if use_pos_embed:
+ self.patch_embed.append(ADDPosEmbed(feat_max_size, embed_dim))
+ if flatten:
+ self.patch_embed.append(FlattenTranspose())
+
+ def forward(self, x):
+ sz = x.shape[2:]
+ x = self.patch_embed(x)
+ return x, [sz[0] // 4, sz[1] // 4]
+
+
+class LastStage(nn.Module):
+
+ def __init__(self, in_channels, out_channels, last_drop, out_char_num):
+ super().__init__()
+ self.last_conv = nn.Linear(in_channels, out_channels, bias=False)
+ self.hardswish = nn.Hardswish()
+ self.dropout = nn.Dropout(p=last_drop)
+
+ def forward(self, x, sz):
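+ # Average over the height dimension to produce a width-aligned 1D feature
+ # sequence for the recognition head.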
+ x = x.reshape(-1, sz[0], sz[1], x.shape[-1])
+ x = x.mean(1)
+ x = self.last_conv(x)
+ x = self.hardswish(x)
+ x = self.dropout(x)
+ return x, [1, sz[1]]
+
+
+class Feat2D(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, sz):
+ # (B, L, C) tokens -> (B, C, H, W) feature map
+ C = x.shape[-1]
+ x = x.transpose(1, 2).reshape(-1, C, sz[0], sz[1])
+ return x, sz
+
+
+class SVTRv2LNConv(nn.Module):
+
+ def __init__(self,
+ max_sz=[32, 128],
+ in_channels=3,
+ out_channels=192,
+ out_char_num=25,
+ depths=[3, 6, 3],
+ dims=[64, 128, 256],
+ mixer=[['Conv'] * 3, ['Conv'] * 3 + ['Global'] * 3,
+ ['Global'] * 3],
+ use_pos_embed=True,
+ sub_k=[[1, 1], [2, 1], [1, 1]],
+ num_heads=[2, 4, 8],
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ last_drop=0.1,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ act=nn.GELU,
+ last_stage=False,
+ feat2d=False,
+ eps=1e-6,
+ **kwargs):
+ super().__init__()
+ num_stages = len(depths)
+ self.num_features = dims[-1]
+
+ feat_max_size = [max_sz[0] // 4, max_sz[1] // 4]
+ self.pope = POPatchEmbed(
+ in_channels=in_channels,
+ feat_max_size=feat_max_size,
+ embed_dim=dims[0],
+ use_pos_embed=use_pos_embed,
+ flatten=mixer[0][0] != 'Conv',
+ )
+
+ dpr = np.linspace(0, drop_path_rate,
+ sum(depths)) # stochastic depth decay rule
+
+ self.stages = nn.ModuleList()
+ for i_stage in range(num_stages):
+ stage = SVTRStage(
+ dim=dims[i_stage],
+ out_dim=dims[i_stage + 1] if i_stage < num_stages - 1 else 0,
+ depth=depths[i_stage],
+ mixer=mixer[i_stage],
+ sub_k=sub_k[i_stage],
+ num_heads=num_heads[i_stage],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_stage]):sum(depths[:i_stage + 1])],
+ norm_layer=norm_layer,
+ act=act,
+ downsample=False if i_stage == num_stages - 1 else True,
+ eps=eps,
+ )
+ self.stages.append(stage)
+
+ self.out_channels = self.num_features
+ self.last_stage = last_stage
+ if last_stage:
+ self.out_channels = out_channels
+ self.stages.append(
+ LastStage(self.num_features, out_channels, last_drop,
+ out_char_num))
+ if feat2d:
+ self.stages.append(Feat2D())
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m: nn.Module):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, mean=0, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ if isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+ if isinstance(m, nn.Conv2d):
+ kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'patch_embed', 'downsample', 'pos_embed'}
+
+ def forward(self, x):
+ x, sz = self.pope(x)
+ for stage in self.stages:
+ x, sz = stage(x, sz)
+ return x
diff --git a/openrec/modeling/encoders/svtrv2_lnconv_two33.py b/openrec/modeling/encoders/svtrv2_lnconv_two33.py
new file mode 100644
index 0000000000000000000000000000000000000000..df8ea14a09f01bd47361e2dc23162f59a395038a
--- /dev/null
+++ b/openrec/modeling/encoders/svtrv2_lnconv_two33.py
@@ -0,0 +1,517 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_
+
+from openrec.modeling.common import DropPath, Identity, Mlp
+
+
+class ConvBNLayer(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=0,
+ bias=False,
+ groups=1,
+ act=nn.GELU,
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ bias=bias,
+ )
+ self.norm = nn.BatchNorm2d(out_channels)
+ self.act = act()
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.norm(out)
+ out = self.act(out)
+ return out
+
+
+class Attention(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.dim = dim
+ self.head_dim = dim // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, _ = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+ self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0)
+ attn = q @ k.transpose(-2, -1) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+ x = x.transpose(1, 2).reshape(B, N, self.dim)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ eps=1e-6,
+ ):
+ super().__init__()
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.norm1 = norm_layer(dim, eps=eps)
+ self.mixer = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ self.norm2 = norm_layer(dim, eps=eps)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, x):
+ x = self.norm1(x + self.drop_path(self.mixer(x)))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ return x
+
+
+class FlattenBlockRe2D(Block):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ mlp_ratio=4,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0,
+ attn_drop=0,
+ drop_path=0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ eps=0.000001):
+ super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
+ attn_drop, drop_path, act_layer, norm_layer, eps)
+
+ def forward(self, x):
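+ # Flatten the 2D feature map to tokens, apply the attention block, then
+ # restore the original (B, C, H, W) layout.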
+ B, C, H, W = x.shape
+ x = x.flatten(2).transpose(1, 2)
+ x = super().forward(x)
+ x = x.transpose(1, 2).reshape(B, C, H, W)
+ return x
+
+
+class ConvBlock(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ eps=1e-6,
+ num_conv=2,
+ kernel_size=3,
+ ):
+ super().__init__()
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.norm1 = norm_layer(dim, eps=eps)
+ self.mixer = nn.Sequential(*[
+ nn.Conv2d(
+ dim, dim, kernel_size, 1, kernel_size // 2, groups=num_heads)
+ for i in range(num_conv)
+ ])
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ self.norm2 = norm_layer(dim, eps=eps)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, x):
+ C, H, W = x.shape[1:]
+ x = x + self.drop_path(self.mixer(x))
+ x = self.norm1(x.flatten(2).transpose(1, 2))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ x = x.transpose(1, 2).reshape(-1, C, H, W)
+ return x
+
+
+class FlattenTranspose(nn.Module):
+
+ def forward(self, x):
+ return x.flatten(2).transpose(1, 2)
+
+
+class SubSample2D(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride=[2, 1],
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1)
+ self.norm = nn.LayerNorm(out_channels)
+
+ def forward(self, x, sz):
+ x = self.conv(x)
+ C, H, W = x.shape[1:]
+ x = self.norm(x.flatten(2).transpose(1, 2))
+ x = x.transpose(1, 2).reshape(-1, C, H, W)
+ return x, [H, W]
+
+
+class SubSample1D(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride=[2, 1],
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1)
+ self.norm = nn.LayerNorm(out_channels)
+
+ def forward(self, x, sz):
+ C = x.shape[-1]
+ x = x.transpose(1, 2).reshape(-1, C, sz[0], sz[1])
+ x = self.conv(x)
+ C, H, W = x.shape[1:]
+ x = self.norm(x.flatten(2).transpose(1, 2))
+ return x, [H, W]
+
+
+class IdentitySize(nn.Module):
+
+ def forward(self, x, sz):
+ return x, sz
+
+
+class SVTRStage(nn.Module):
+
+ def __init__(self,
+ dim=64,
+ out_dim=256,
+ depth=3,
+ mixer=['Local'] * 3,
+ kernel_sizes=[3] * 3,
+ sub_k=[2, 1],
+ num_heads=2,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path=[0.1] * 3,
+ norm_layer=nn.LayerNorm,
+ act=nn.GELU,
+ eps=1e-6,
+ num_conv=[2] * 3,
+ downsample=None,
+ **kwargs):
+ super().__init__()
+ self.dim = dim
+
+ self.blocks = nn.Sequential()
+ for i in range(depth):
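+ # 'Conv' adds a local conv block; 'Global' attends on the token sequence;
+ # 'FGlobal' flattens to tokens first; 'FGlobalRe2D' attends and restores 2D.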
+ if mixer[i] == 'Conv':
+ self.blocks.append(
+ ConvBlock(dim=dim,
+ kernel_size=kernel_sizes[i],
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ drop=drop_rate,
+ act_layer=act,
+ drop_path=drop_path[i],
+ norm_layer=norm_layer,
+ eps=eps,
+ num_conv=num_conv[i]))
+ else:
+ if mixer[i] == 'Global':
+ block = Block
+ elif mixer[i] == 'FGlobal':
+ block = Block
+ self.blocks.append(FlattenTranspose())
+ elif mixer[i] == 'FGlobalRe2D':
+ block = FlattenBlockRe2D
+ self.blocks.append(
+ block(
+ dim=dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=act,
+ attn_drop=attn_drop_rate,
+ drop_path=drop_path[i],
+ norm_layer=norm_layer,
+ eps=eps,
+ ))
+
+ if downsample:
+ if mixer[-1] == 'Conv' or mixer[-1] == 'FGlobalRe2D':
+ self.downsample = SubSample2D(dim, out_dim, stride=sub_k)
+ else:
+ self.downsample = SubSample1D(dim, out_dim, stride=sub_k)
+ else:
+ self.downsample = IdentitySize()
+
+ def forward(self, x, sz):
+ for blk in self.blocks:
+ x = blk(x)
+ x, sz = self.downsample(x, sz)
+ return x, sz
+
+
+class ADDPosEmbed(nn.Module):
+
+ def __init__(self, feat_max_size=[8, 32], embed_dim=768):
+ super().__init__()
+ pos_embed = torch.zeros(
+ [1, feat_max_size[0] * feat_max_size[1], embed_dim],
+ dtype=torch.float32)
+ trunc_normal_(pos_embed, mean=0, std=0.02)
+ self.pos_embed = nn.Parameter(
+ pos_embed.transpose(1, 2).reshape(1, embed_dim, feat_max_size[0],
+ feat_max_size[1]),
+ requires_grad=True,
+ )
+
+ def forward(self, x):
+ sz = x.shape[2:]
+ x = x + self.pos_embed[:, :, :sz[0], :sz[1]]
+ return x
+
+
+class POPatchEmbed(nn.Module):
+ """Image to Patch Embedding."""
+
+ def __init__(self,
+ in_channels=3,
+ feat_max_size=[8, 32],
+ embed_dim=768,
+ use_pos_embed=False,
+ flatten=False,
+ bias=False):
+ super().__init__()
+ self.patch_embed = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=bias,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=bias,
+ ),
+ )
+ if use_pos_embed:
+ self.patch_embed.append(ADDPosEmbed(feat_max_size, embed_dim))
+ if flatten:
+ self.patch_embed.append(FlattenTranspose())
+
+ def forward(self, x):
+ sz = x.shape[2:]
+ x = self.patch_embed(x)
+ return x, [sz[0] // 4, sz[1] // 4]
+
+
+class LastStage(nn.Module):
+
+ def __init__(self, in_channels, out_channels, last_drop, out_char_num=0):
+ super().__init__()
+ self.last_conv = nn.Linear(in_channels, out_channels, bias=False)
+ self.hardswish = nn.Hardswish()
+ self.dropout = nn.Dropout(p=last_drop)
+
+ def forward(self, x, sz):
+ x = x.reshape(-1, sz[0], sz[1], x.shape[-1])
+ x = x.mean(1)
+ x = self.last_conv(x)
+ x = self.hardswish(x)
+ x = self.dropout(x)
+ return x, [1, sz[1]]
+
+
+class Feat2D(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, sz):
+ C = x.shape[-1]
+ x = x.transpose(1, 2).reshape(-1, C, sz[0], sz[1])
+ return x, sz
+
+
+class SVTRv2LNConvTwo33(nn.Module):
+
+ def __init__(self,
+ max_sz=[32, 128],
+ in_channels=3,
+ out_channels=192,
+ depths=[3, 6, 3],
+ dims=[64, 128, 256],
+ mixer=[['Conv'] * 3, ['Conv'] * 3 + ['Global'] * 3,
+ ['Global'] * 3],
+ use_pos_embed=True,
+ sub_k=[[1, 1], [2, 1], [1, 1]],
+ num_heads=[2, 4, 8],
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ last_drop=0.1,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ act=nn.GELU,
+ last_stage=False,
+ feat2d=False,
+ eps=1e-6,
+ num_convs=[[2] * 3, [2] * 3 + [3] * 3, [3] * 3],
+ kernel_sizes=[[3] * 3, [3] * 3 + [3] * 3, [3] * 3],
+ pope_bias=False,
+ **kwargs):
+ super().__init__()
+ num_stages = len(depths)
+ self.num_features = dims[-1]
+
+ feat_max_size = [max_sz[0] // 4, max_sz[1] // 4]
+ self.pope = POPatchEmbed(in_channels=in_channels,
+ feat_max_size=feat_max_size,
+ embed_dim=dims[0],
+ use_pos_embed=use_pos_embed,
+ flatten=mixer[0][0] != 'Conv',
+ bias=pope_bias)
+
+ dpr = np.linspace(0, drop_path_rate,
+ sum(depths)) # stochastic depth decay rule
+
+ self.stages = nn.ModuleList()
+ for i_stage in range(num_stages):
+ stage = SVTRStage(
+ dim=dims[i_stage],
+ out_dim=dims[i_stage + 1] if i_stage < num_stages - 1 else 0,
+ depth=depths[i_stage],
+ mixer=mixer[i_stage],
+ kernel_sizes=kernel_sizes[i_stage]
+ if len(kernel_sizes[i_stage]) == len(mixer[i_stage]) else [3] *
+ len(mixer[i_stage]),
+ sub_k=sub_k[i_stage],
+ num_heads=num_heads[i_stage],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_stage]):sum(depths[:i_stage + 1])],
+ norm_layer=norm_layer,
+ act=act,
+ downsample=False if i_stage == num_stages - 1 else True,
+ eps=eps,
+ num_conv=num_convs[i_stage] if len(num_convs[i_stage]) == len(
+ mixer[i_stage]) else [2] * len(mixer[i_stage]),
+ )
+ self.stages.append(stage)
+
+ self.out_channels = self.num_features
+ self.last_stage = last_stage
+ if last_stage:
+ self.out_channels = out_channels
+ self.stages.append(
+ LastStage(self.num_features, out_channels, last_drop))
+ if feat2d:
+ self.stages.append(Feat2D())
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m: nn.Module):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, mean=0, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ if isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+ if isinstance(m, nn.Conv2d):
+ kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'patch_embed', 'downsample', 'pos_embed'}
+
+ def forward(self, x):
+ if len(x.shape) == 5:
+ x = x.flatten(0, 1)
+ x, sz = self.pope(x)
+ for stage in self.stages:
+ x, sz = stage(x, sz)
+ return x
diff --git a/openrec/modeling/encoders/vit.py b/openrec/modeling/encoders/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..a353974615f34b51a50e530333ead5024fb0bcc4
--- /dev/null
+++ b/openrec/modeling/encoders/vit.py
@@ -0,0 +1,120 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_
+
+from openrec.modeling.common import Block, PatchEmbed
+from openrec.modeling.encoders.svtrv2_lnconv import Feat2D, LastStage
+
+
+class ViT(nn.Module):
+
+ def __init__(
+ self,
+ img_size=[32, 128],
+ patch_size=[4, 8],
+ in_channels=3,
+ out_channels=256,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ norm_layer=nn.LayerNorm,
+ act_layer=nn.GELU,
+ last_stage=False,
+ feat2d=False,
+ use_cls_token=False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.img_size = img_size
+ self.embed_dim = embed_dim
+ self.out_channels = embed_dim
+ self.use_cls_token = use_cls_token
+ self.feat_sz = [
+ img_size[0] // patch_size[0], img_size[1] // patch_size[1]
+ ]
+
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_channels,
+ embed_dim)
+ num_patches = self.patch_embed.num_patches
+ if use_cls_token:
+ self.cls_token = nn.Parameter(
+ torch.zeros([1, 1, embed_dim], dtype=torch.float32),
+ requires_grad=True,
+ )
+ trunc_normal_(self.cls_token, mean=0, std=0.02)
+ self.pos_embed = nn.Parameter(
+ torch.zeros([1, num_patches + 1, embed_dim],
+ dtype=torch.float32),
+ requires_grad=True,
+ )
+ else:
+ self.pos_embed = nn.Parameter(
+ torch.zeros([1, num_patches, embed_dim], dtype=torch.float32),
+ requires_grad=True,
+ )
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = np.linspace(0, drop_path_rate, depth)
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=act_layer,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ ) for i in range(depth)
+ ])
+ self.norm = norm_layer(embed_dim)
+ self.last_stage = last_stage
+ self.feat2d = feat2d
+ if last_stage:
+ self.out_channels = out_channels
+ self.stages = LastStage(embed_dim, out_channels, last_drop=0.1)
+ if feat2d:
+ self.stages = Feat2D()
+ trunc_normal_(self.pos_embed, mean=0, std=0.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, mean=0, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ if isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+ if isinstance(m, nn.Conv2d):
+ kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed'}
+
+ def forward(self, x):
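+ # Patch embedding plus learned positional embedding, then the transformer
+ # blocks; the optional cls token is prepended here and dropped afterwards.
+ # Note that last_stage and feat2d assign the same `stages` head in __init__.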
+ x = self.patch_embed(x)
+ if self.use_cls_token:
+ x = torch.concat([self.cls_token.tile([x.shape[0], 1, 1]), x], 1)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+ for blk in self.blocks:
+ x = blk(x)
+ x = self.norm(x)
+ if self.use_cls_token:
+ x = x[:, 1:, :]
+ if self.last_stage:
+ x, sz = self.stages(x, self.feat_sz)
+ if self.feat2d:
+ x, sz = self.stages(x, self.feat_sz)
+ return x
diff --git a/openrec/modeling/transforms/__init__.py b/openrec/modeling/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..225a1fbc055f292c532e974d86daeb41962112b5
--- /dev/null
+++ b/openrec/modeling/transforms/__init__.py
@@ -0,0 +1,15 @@
+__all__ = ['build_transform']
+
+
+def build_transform(config):
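+ # Instantiate the requested rectification module (TPS, Aster_TPS or MORN)
+ # from the config dict; 'name' selects the class, the rest are kwargs.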
+ from .aster_tps import Aster_TPS
+ from .moran import MORN
+ from .tps import TPS
+
+ support_dict = ['TPS', 'Aster_TPS', 'MORN']
+
+ module_name = config.pop('name')
+ assert module_name in support_dict, Exception(
+ 'transform only support {}'.format(support_dict))
+ module_class = eval(module_name)(**config)
+ return module_class
diff --git a/openrec/modeling/transforms/aster_tps.py b/openrec/modeling/transforms/aster_tps.py
new file mode 100644
index 0000000000000000000000000000000000000000..acd9baa22d511eec20e96fe1b085b6dc50d1ad1b
--- /dev/null
+++ b/openrec/modeling/transforms/aster_tps.py
@@ -0,0 +1,262 @@
+import itertools
+import math
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def conv3x3_block(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding."""
+ conv_layer = nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1)
+
+ block = nn.Sequential(
+ conv_layer,
+ nn.BatchNorm2d(out_planes),
+ nn.ReLU(inplace=True),
+ )
+ return block
+
+
+class STNHead(nn.Module):
+
+ def __init__(self, in_planes, num_ctrlpoints, activation='none'):
+ super(STNHead, self).__init__()
+
+ self.in_planes = in_planes
+ self.num_ctrlpoints = num_ctrlpoints
+ self.activation = activation
+ self.stn_convnet = nn.Sequential(
+ conv3x3_block(in_planes, 32), # 32*64
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ conv3x3_block(32, 64), # 16*32
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ conv3x3_block(64, 128), # 8*16
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ conv3x3_block(128, 256), # 4*8
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ conv3x3_block(256, 256), # 2*4,
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ conv3x3_block(256, 256)) # 1*2
+
+ self.stn_fc1 = nn.Sequential(nn.Linear(2 * 256, 512),
+ nn.BatchNorm1d(512),
+ nn.ReLU(inplace=True))
+ self.stn_fc2 = nn.Linear(512, num_ctrlpoints * 2)
+
+ self.init_weights(self.stn_convnet)
+ self.init_weights(self.stn_fc1)
+ self.init_stn(self.stn_fc2)
+
+ def init_weights(self, module):
+ for m in module.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ m.weight.data.normal_(0, 0.001)
+ m.bias.data.zero_()
+
+ def init_stn(self, stn_fc2):
+ margin = 0.01
+ sampling_num_per_side = int(self.num_ctrlpoints / 2)
+ ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side)
+ ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
+ ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin)
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom],
+ axis=0).astype(np.float32)
+ if self.activation == 'none':
+ pass
+ elif self.activation == 'sigmoid':
+ ctrl_points = -np.log(1. / ctrl_points - 1.)
+ stn_fc2.weight.data.zero_()
+ stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1)
+
+ def forward(self, x):
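+ # Regress num_ctrlpoints 2D control points; fc2 is initialized (zero weight,
+ # canonical-grid bias) so the initial prediction matches the target layout.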
+ x = self.stn_convnet(x)
+ batch_size, _, h, w = x.size()
+ x = x.view(batch_size, -1)
+ img_feat = self.stn_fc1(x)
+ x = self.stn_fc2(0.1 * img_feat)
+ if self.activation == 'sigmoid':
+ x = F.sigmoid(x)
+ x = x.view(-1, self.num_ctrlpoints, 2)
+ return x
+
+
+def grid_sample(input, grid, canvas=None):
+ output = F.grid_sample(input, grid)
+ if canvas is None:
+ return output
+ else:
+ input_mask = input.data.new(input.size()).fill_(1)
+ output_mask = F.grid_sample(input_mask, grid)
+ padded_output = output * output_mask + canvas * (1 - output_mask)
+ return padded_output
+
+
+# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
+def compute_partial_repr(input_points, control_points):
+ N = input_points.size(0)
+ M = control_points.size(0)
+ pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
+ # original implementation, very slow
+ # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
+ pairwise_diff_square = pairwise_diff * pairwise_diff
+ pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :,
+ 1]
+ repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
+ # fix numerical error for 0 * log(0), substitute all nan with 0
+ mask = repr_matrix != repr_matrix
+ repr_matrix.masked_fill_(mask, 0)
+ return repr_matrix
+
+
+# output_ctrl_pts are specified, according to our task.
+def build_output_control_points(num_control_points, margins):
+ margin_x, margin_y = margins
+ num_ctrl_pts_per_side = num_control_points // 2
+ ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
+ ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
+ ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ # ctrl_pts_top = ctrl_pts_top[1:-1,:]
+ # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:]
+ output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom],
+ axis=0)
+ output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
+ return output_ctrl_pts
+
+
+class TPSSpatialTransformer(nn.Module):
+
+ def __init__(
+ self,
+ output_image_size,
+ num_control_points,
+ margins,
+ ):
+ super(TPSSpatialTransformer, self).__init__()
+ self.output_image_size = output_image_size
+ self.num_control_points = num_control_points
+ self.margins = margins
+
+ self.target_height, self.target_width = output_image_size
+ target_control_points = build_output_control_points(
+ num_control_points, margins)
+ N = num_control_points
+ # N = N - 4
+
+ # create padded kernel matrix
+ forward_kernel = torch.zeros(N + 3, N + 3)
+ target_control_partial_repr = compute_partial_repr(
+ target_control_points, target_control_points)
+ forward_kernel[:N, :N].copy_(target_control_partial_repr)
+ forward_kernel[:N, -3].fill_(1)
+ forward_kernel[-3, :N].fill_(1)
+ forward_kernel[:N, -2:].copy_(target_control_points)
+ forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
+ # compute inverse matrix
+ inverse_kernel = torch.inverse(forward_kernel)
+
+ # create target cordinate matrix
+ HW = self.target_height * self.target_width
+ target_coordinate = list(
+ itertools.product(range(self.target_height),
+ range(self.target_width)))
+ target_coordinate = torch.Tensor(target_coordinate) # HW x 2
+ Y, X = target_coordinate.split(1, dim=1)
+ Y = Y / (self.target_height - 1)
+ X = X / (self.target_width - 1)
+ target_coordinate = torch.cat([X, Y],
+ dim=1) # convert from (y, x) to (x, y)
+ target_coordinate_partial_repr = compute_partial_repr(
+ target_coordinate, target_control_points)
+ target_coordinate_repr = torch.cat([
+ target_coordinate_partial_repr,
+ torch.ones(HW, 1), target_coordinate
+ ],
+ dim=1)
+
+ # register precomputed matrices
+ self.register_buffer('inverse_kernel', inverse_kernel)
+ self.register_buffer('padding_matrix', torch.zeros(3, 2))
+ self.register_buffer('target_coordinate_repr', target_coordinate_repr)
+ self.register_buffer('target_control_points', target_control_points)
+
+ def forward(self, input, source_control_points):
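+ # Solve the TPS mapping from the predicted source control points, build the
+ # sampling grid in [-1, 1] and warp the input image with grid_sample.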
+ assert source_control_points.ndimension() == 3
+ assert source_control_points.size(1) == self.num_control_points
+ assert source_control_points.size(2) == 2
+ batch_size = source_control_points.size(0)
+
+ Y = torch.cat([
+ source_control_points,
+ self.padding_matrix.expand(batch_size, 3, 2)
+ ], 1)
+ mapping_matrix = torch.matmul(self.inverse_kernel, Y)
+ source_coordinate = torch.matmul(self.target_coordinate_repr,
+ mapping_matrix)
+
+ grid = source_coordinate.view(-1, self.target_height,
+ self.target_width, 2)
+ grid = torch.clamp(
+ grid, 0, 1) # the source_control_points may be out of [0, 1].
+ # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
+ grid = 2.0 * grid - 1.0
+ output_maps = grid_sample(input, grid, canvas=None)
+ return output_maps
+
+
+class Aster_TPS(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ tps_inputsize=[32, 64],
+ tps_outputsize=[32, 100],
+ num_control_points=20,
+ tps_margins=[0.05, 0.05],
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ #TODO
+ self.out_channels = in_channels
+ self.tps_inputsize = tps_inputsize
+ self.num_control_points = num_control_points
+
+ self.stn_head = STNHead(
+ in_planes=3,
+ num_ctrlpoints=num_control_points,
+ )
+
+ self.tps = TPSSpatialTransformer(
+ output_image_size=tps_outputsize,
+ num_control_points=num_control_points,
+ margins=tps_margins,
+ )
+
+ def forward(self, img):
+ stn_input = F.interpolate(img,
+ self.tps_inputsize,
+ mode='bilinear',
+ align_corners=True)
+
+ ctrl_points = self.stn_head(stn_input)
+
+ img = self.tps(img, ctrl_points)
+
+ return img
diff --git a/openrec/modeling/transforms/moran.py b/openrec/modeling/transforms/moran.py
new file mode 100644
index 0000000000000000000000000000000000000000..01d7276148e329338958d91ca0ec137afac59cc6
--- /dev/null
+++ b/openrec/modeling/transforms/moran.py
@@ -0,0 +1,136 @@
+"""This code is refer from:
+https://github.com/Canjie-Luo/MORAN_v2
+"""
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class MORN(nn.Module):
+
+ def __init__(self, in_channels, target_shape=[32, 100], enhance=1):
+ super(MORN, self).__init__()
+ self.targetH = target_shape[0]
+ self.targetW = target_shape[1]
+ self.enhance = enhance
+ self.out_channels = in_channels
+ self.cnn = nn.Sequential(nn.MaxPool2d(2, 2),
+ nn.Conv2d(in_channels, 64, 3, 1, 1),
+ nn.BatchNorm2d(64), nn.ReLU(True),
+ nn.MaxPool2d(2,
+ 2), nn.Conv2d(64, 128, 3, 1, 1),
+ nn.BatchNorm2d(128), nn.ReLU(True),
+ nn.MaxPool2d(2,
+ 2), nn.Conv2d(128, 64, 3, 1, 1),
+ nn.BatchNorm2d(64), nn.ReLU(True),
+ nn.Conv2d(64, 16, 3, 1, 1),
+ nn.BatchNorm2d(16), nn.ReLU(True),
+ nn.Conv2d(16, 1, 3, 1, 1), nn.BatchNorm2d(1))
+ self.pool = nn.MaxPool2d(2, 1)
+ h_list = np.arange(self.targetH) * 2. / (self.targetH - 1) - 1
+ w_list = np.arange(self.targetW) * 2. / (self.targetW - 1) - 1
+ grid = np.meshgrid(w_list, h_list, indexing='ij')
+ grid = np.stack(grid, axis=-1)
+ grid = np.transpose(grid, (1, 0, 2))
+ grid = np.expand_dims(grid, 0)
+ self.grid = nn.Parameter(
+ torch.from_numpy(grid).float(),
+ requires_grad=False,
+ )
+
+ def forward(self, x):
+
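+ # Predict per-pixel vertical offsets on a resized copy of the input, warp,
+ # and refine the offsets `enhance` more times on the rectified image.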
+ bs = x.shape[0]
+ grid = self.grid.tile([bs, 1, 1, 1])
+ grid_x = self.grid[:, :, :, 0].unsqueeze(3).tile([bs, 1, 1, 1])
+ grid_y = self.grid[:, :, :, 1].unsqueeze(3).tile([bs, 1, 1, 1])
+ x_small = F.interpolate(x,
+ size=(self.targetH, self.targetW),
+ mode='bilinear',
+ align_corners=False)
+
+ offsets = self.cnn(x_small)
+ offsets_posi = F.relu(offsets, inplace=False)
+ offsets_nega = F.relu(-offsets, inplace=False)
+ offsets_pool = self.pool(offsets_posi) - self.pool(offsets_nega)
+
+ offsets_grid = F.grid_sample(offsets_pool, grid)
+ offsets_grid = offsets_grid.permute(0, 2, 3, 1).contiguous()
+ offsets_x = torch.cat([grid_x, grid_y + offsets_grid], 3)
+ x_rectified = F.grid_sample(x, offsets_x)
+
+ for iteration in range(self.enhance):
+ offsets = self.cnn(x_rectified)
+
+ offsets_posi = F.relu(offsets, inplace=False)
+ offsets_nega = F.relu(-offsets, inplace=False)
+ offsets_pool = self.pool(offsets_posi) - self.pool(offsets_nega)
+
+ offsets_grid += F.grid_sample(offsets_pool,
+ grid).permute(0, 2, 3,
+ 1).contiguous()
+ offsets_x = torch.cat([grid_x, grid_y + offsets_grid], 3)
+ x_rectified = F.grid_sample(x, offsets_x)
+
+ return x_rectified
diff --git a/openrec/modeling/transforms/tps.py b/openrec/modeling/transforms/tps.py
new file mode 100644
index 0000000000000000000000000000000000000000..5db2cc0b94485b2217da4704e310814078c49ad8
--- /dev/null
+++ b/openrec/modeling/transforms/tps.py
@@ -0,0 +1,246 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from openrec.modeling.common import Activation
+
+
+class ConvBNLayer(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ act=None):
+ super(ConvBNLayer, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ bias=False,
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.act = Activation(act) if act else None
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ if self.act is not None:
+ x = self.act(x)
+ return x
+
+
+class LocalizationNetwork(nn.Module):
+
+ def __init__(self, in_channels, num_fiducial, loc_lr, model_name):
+ super(LocalizationNetwork, self).__init__()
+ self.F = num_fiducial
+ F = num_fiducial
+ if model_name == 'large':
+ num_filters_list = [64, 128, 256, 512]
+ fc_dim = 256
+ else:
+ num_filters_list = [16, 32, 64, 128]
+ fc_dim = 64
+
+ self.block_list = nn.ModuleList()
+ for fno in range(0, len(num_filters_list)):
+ num_filters = num_filters_list[fno]
+ conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=num_filters,
+ kernel_size=3,
+ act='relu',
+ )
+ self.block_list.append(conv)
+ if fno == len(num_filters_list) - 1:
+ pool = nn.AdaptiveAvgPool2d(1)
+ else:
+ pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+ in_channels = num_filters
+ self.block_list.append(pool)
+ self.fc1 = nn.Linear(in_channels, fc_dim)
+
+ # Init fc2 in LocalizationNetwork
+ self.fc2 = nn.Linear(fc_dim, F * 2)
+
+ initial_bias = self.get_initial_fiducials()
+ initial_bias = initial_bias.reshape(-1)
+ self.fc2.bias.data = torch.tensor(initial_bias, dtype=torch.float32)
+ nn.init.zeros_(self.fc2.weight.data)
+ self.out_channels = F * 2
+
+ def forward(self, x):
+ """
+ Estimating parameters of geometric transformation
+ Args:
+ image: input
+ Return:
+ batch_C_prime: the matrix of the geometric transformation
+ """
+ for block in self.block_list:
+ x = block(x)
+ x = x.squeeze(dim=2).squeeze(dim=2)
+ x = self.fc1(x)
+
+ x = F.relu(x)
+ x = self.fc2(x)
+ x = x.reshape(shape=[-1, self.F, 2])
+ return x
+
+ def get_initial_fiducials(self):
+ """see RARE paper Fig.
+
+ 6 (a)
+ """
+ F = self.F
+ ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
+ ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
+ ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
+ return initial_bias
+
+
+class GridGenerator(nn.Module):
+
+ def __init__(self, in_channels, num_fiducial):
+ super(GridGenerator, self).__init__()
+ self.eps = 1e-6
+ self.F = num_fiducial
+
+ self.fc = nn.Linear(in_channels, 6)
+ nn.init.constant_(self.fc.weight, 0)
+ nn.init.constant_(self.fc.bias, 0)
+ self.fc.weight.requires_grad = False
+ self.fc.bias.requires_grad = False
+
+ def forward(self, batch_C_prime, I_r_size):
+ """Generate the grid for the grid_sampler.
+
+ Args:
+ batch_C_prime: the matrix of the geometric transformation
+ I_r_size: the shape of the input image
+ Return:
+ batch_P_prime: the grid for the grid_sampler
+ """
+ C = self.build_C_paddle()
+ P = self.build_P_paddle(I_r_size)
+
+ inv_delta_C_tensor = self.build_inv_delta_C_paddle(C).float()
+ P_hat_tensor = self.build_P_hat_paddle(C, P).float()
+
+ batch_C_ex_part_tensor = self.get_expand_tensor(batch_C_prime)
+
+ batch_C_prime_with_zeros = torch.cat(
+ [batch_C_prime, batch_C_ex_part_tensor], dim=1)
+ batch_T = torch.matmul(
+ inv_delta_C_tensor.to(batch_C_prime_with_zeros.device),
+ batch_C_prime_with_zeros,
+ )
+ batch_P_prime = torch.matmul(P_hat_tensor.to(batch_T.device), batch_T)
+ return batch_P_prime
+
+ def build_C_paddle(self):
+ """Return coordinates of fiducial points in I_r; C."""
+ F = self.F
+ ctrl_pts_x = torch.linspace(-1.0, 1.0, int(F / 2), dtype=torch.float64)
+ ctrl_pts_y_top = -1 * torch.ones([int(F / 2)], dtype=torch.float64)
+ ctrl_pts_y_bottom = torch.ones([int(F / 2)], dtype=torch.float64)
+ ctrl_pts_top = torch.stack([ctrl_pts_x, ctrl_pts_y_top], dim=1)
+ ctrl_pts_bottom = torch.stack([ctrl_pts_x, ctrl_pts_y_bottom], dim=1)
+ C = torch.cat([ctrl_pts_top, ctrl_pts_bottom], dim=0)
+ return C # F x 2
+
+ def build_P_paddle(self, I_r_size):
+ I_r_height, I_r_width = I_r_size
+ I_r_grid_x = (torch.arange(-I_r_width, I_r_width, 2) +
+ 1.0) / torch.tensor(np.array([I_r_width]))
+
+ I_r_grid_y = (torch.arange(-I_r_height, I_r_height, 2) +
+ 1.0) / torch.tensor(np.array([I_r_height]))
+
+ # P: self.I_r_width x self.I_r_height x 2
+ P = torch.stack(torch.meshgrid(I_r_grid_x, I_r_grid_y), dim=2)
+ P = torch.permute(P, [1, 0, 2])
+ # n (= self.I_r_width x self.I_r_height) x 2
+ return P.reshape([-1, 2])
+
+ def build_inv_delta_C_paddle(self, C):
+ """Return inv_delta_C which is needed to calculate T."""
+ F = self.F
+ hat_eye = torch.eye(F) # F x F
+ hat_C = torch.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]),
+ dim=2) + hat_eye
+ hat_C = (hat_C**2) * torch.log(hat_C)
+ delta_C = torch.cat( # F+3 x F+3
+ [
+ torch.cat([torch.ones((F, 1)), C, hat_C], dim=1), # F x F+3
+ torch.concat([torch.zeros(
+ (2, 3)), C.transpose(0, 1)], dim=1), # 2 x F+3
+ torch.concat([torch.zeros(
+ (1, 3)), torch.ones((1, F))], dim=1), # 1 x F+3
+ ],
+ axis=0,
+ )
+ inv_delta_C = torch.inverse(delta_C)
+ return inv_delta_C # F+3 x F+3
+
+ def build_P_hat_paddle(self, C, P):
+ F = self.F
+ eps = self.eps
+ n = P.shape[0] # n (= self.I_r_width x self.I_r_height)
+ # P_tile: n x 2 -> n x 1 x 2 -> n x F x 2
+ P_tile = torch.tile(torch.unsqueeze(P, dim=1), (1, F, 1))
+ C_tile = torch.unsqueeze(C, dim=0) # 1 x F x 2
+ P_diff = P_tile - C_tile # n x F x 2
+ # rbf_norm: n x F
+ rbf_norm = torch.norm(P_diff, p=2, dim=2, keepdim=False)
+
+ # rbf: n x F
+ rbf = torch.multiply(torch.square(rbf_norm), torch.log(rbf_norm + eps))
+ P_hat = torch.cat([torch.ones((n, 1)), P, rbf], dim=1)
+ return P_hat # n x F+3
+
+ def get_expand_tensor(self, batch_C_prime):
+ B, H, C = batch_C_prime.shape
+ batch_C_prime = batch_C_prime.reshape([B, H * C])
+ batch_C_ex_part_tensor = self.fc(batch_C_prime)
+ batch_C_ex_part_tensor = batch_C_ex_part_tensor.reshape([-1, 3, 2])
+ return batch_C_ex_part_tensor
+
+
+class TPS(nn.Module):
+
+ def __init__(self, in_channels, num_fiducial, loc_lr, model_name):
+ super(TPS, self).__init__()
+ self.loc_net = LocalizationNetwork(in_channels, num_fiducial, loc_lr,
+ model_name)
+ self.grid_generator = GridGenerator(self.loc_net.out_channels,
+ num_fiducial)
+ self.out_channels = in_channels
+
+ def forward(self, image):
+ batch_C_prime = self.loc_net(image)
+ batch_P_prime = self.grid_generator(batch_C_prime, image.shape[2:])
+ batch_P_prime = batch_P_prime.reshape(
+ [-1, image.shape[2], image.shape[3], 2])
+ is_fp16 = False
+ if batch_P_prime.dtype != torch.float32:
+ data_type = batch_P_prime.dtype
+ image = image.float()
+ batch_P_prime = batch_P_prime.float()
+ is_fp16 = True
+ batch_I_r = F.grid_sample(image, grid=batch_P_prime)
+ if is_fp16:
+ batch_I_r = batch_I_r.to(data_type)
+
+ return batch_I_r
diff --git a/openrec/optimizer/__init__.py b/openrec/optimizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9101e4eee8ed7c2877a1d18cbb9db8b53836f9c
--- /dev/null
+++ b/openrec/optimizer/__init__.py
@@ -0,0 +1,73 @@
+import copy
+
+import torch
+from torch import nn
+
+__all__ = ['build_optimizer']
+
+
+def param_groups_weight_decay(model: nn.Module,
+ weight_decay=1e-5,
+ no_weight_decay_list=()):
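+ # Split parameters into two groups: 1-D tensors, biases and names in
+ # no_weight_decay_list get zero weight decay; the rest use weight_decay.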
+ no_weight_decay_list = set(no_weight_decay_list)
+ decay = []
+ no_decay = []
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue
+
+ if param.ndim <= 1 or name.endswith(
+ '.bias') or any(nd in name for nd in no_weight_decay_list):
+ no_decay.append(param)
+ else:
+ decay.append(param)
+
+ return [
+ {
+ 'params': no_decay,
+ 'weight_decay': 0.0
+ },
+ {
+ 'params': decay,
+ 'weight_decay': weight_decay
+ },
+ ]
+
+
+def build_optimizer(optim_config, lr_scheduler_config, epochs, step_each_epoch,
+ model):
+ from . import lr
+
+ config = copy.deepcopy(optim_config)
+
+ if isinstance(model, nn.Module):
+ # a model was passed in, extract parameters and add weight decays to appropriate layers
+ weight_decay = config.get('weight_decay', 0.0)
+ filter_bias_and_bn = (config.pop('filter_bias_and_bn')
+ if 'filter_bias_and_bn' in config else False)
+ if weight_decay > 0.0 and filter_bias_and_bn:
+ no_weight_decay = {}
+ if hasattr(model, 'no_weight_decay'):
+ no_weight_decay = model.no_weight_decay()
+ parameters = param_groups_weight_decay(model, weight_decay,
+ no_weight_decay)
+ config['weight_decay'] = 0.0
+ else:
+ parameters = model.parameters()
+ else:
+ # iterable of parameters or param groups passed in
+ parameters = model
+
+ optim = getattr(torch.optim, config.pop('name'))(params=parameters,
+ **config)
+
+ lr_config = copy.deepcopy(lr_scheduler_config)
+ lr_config.update({
+ 'epochs': epochs,
+ 'step_each_epoch': step_each_epoch,
+ 'lr': config['lr']
+ })
+ lr_scheduler = getattr(lr,
+ lr_config.pop('name'))(**lr_config)(optimizer=optim)
+ return optim, lr_scheduler
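+
+
+# Illustrative call (config values are hypothetical, in the YAML style used by
+# this repo):
+#
+#   optim, sched = build_optimizer(
+#       optim_config={'name': 'AdamW', 'lr': 1e-3, 'weight_decay': 0.05,
+#                     'filter_bias_and_bn': True},
+#       lr_scheduler_config={'name': 'CosineAnnealingLR', 'warmup_epoch': 2},
+#       epochs=20, step_each_epoch=1000, model=model)
+#
+# 'name' is popped and resolved via getattr against torch.optim and
+# openrec.optimizer.lr respectively, so any class defined there is selectable
+# from the config.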
diff --git a/openrec/optimizer/lr.py b/openrec/optimizer/lr.py
new file mode 100644
index 0000000000000000000000000000000000000000..759a946665437c7565f120ca8ad422867f7b7b35
--- /dev/null
+++ b/openrec/optimizer/lr.py
@@ -0,0 +1,227 @@
+import math
+from functools import partial
+
+import numpy as np
+from torch.optim import lr_scheduler
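+
+# Each class below is a small factory: calling it with an optimizer returns a
+# torch.optim.lr_scheduler object (usually a LambdaLR). The lambda functions
+# return a multiplicative factor applied to the optimizer's base lr and are
+# expressed in training steps (epochs * step_each_epoch); warmup_epoch gives a
+# linear ramp from 0 to 1 over the warmup steps. For example, with base lr
+# 1e-3, warmup_epoch=2 and 100 steps per epoch, CosineAnnealingLR yields
+# lr = 1e-3 * step / 200 during warmup and then
+# lr = 1e-3 * 0.5 * (1 + cos(pi * progress)) for the remaining steps.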
+
+
+class StepLR(object):
+
+ def __init__(self,
+ step_each_epoch,
+ step_size,
+ warmup_epoch=0,
+ gamma=0.1,
+ last_epoch=-1,
+ **kwargs):
+ super(StepLR, self).__init__()
+ self.step_size = step_each_epoch * step_size
+ self.gamma = gamma
+ self.last_epoch = last_epoch
+        self.warmup_epoch = warmup_epoch * step_each_epoch
+
+ def __call__(self, optimizer):
+ return lr_scheduler.LambdaLR(optimizer, self.lambda_func,
+ self.last_epoch)
+
+ def lambda_func(self, current_step):
+ if current_step < self.warmup_epoch:
+ return float(current_step) / float(max(1, self.warmup_epoch))
+ return self.gamma**(current_step // self.step_size)
+
+
+class MultiStepLR(object):
+
+ def __init__(self,
+ step_each_epoch,
+ milestones,
+ warmup_epoch=0,
+ gamma=0.1,
+ last_epoch=-1,
+ **kwargs):
+ super(MultiStepLR, self).__init__()
+ self.milestones = [step_each_epoch * e for e in milestones]
+ self.gamma = gamma
+ self.last_epoch = last_epoch
+        self.warmup_epoch = warmup_epoch * step_each_epoch
+
+ def __call__(self, optimizer):
+ return lr_scheduler.LambdaLR(optimizer, self.lambda_func,
+ self.last_epoch)
+
+ def lambda_func(self, current_step):
+ if current_step < self.warmup_epoch:
+ return float(current_step) / float(max(1, self.warmup_epoch))
+ return self.gamma**len(
+ [m for m in self.milestones if m <= current_step])
+
+
+class ConstLR(object):
+
+ def __init__(self,
+ step_each_epoch,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs):
+ super(ConstLR, self).__init__()
+ self.last_epoch = last_epoch
+ self.warmup_epoch = warmup_epoch * step_each_epoch
+
+ def __call__(self, optimizer):
+ return lr_scheduler.LambdaLR(optimizer, self.lambda_func,
+ self.last_epoch)
+
+ def lambda_func(self, current_step):
+ if current_step < self.warmup_epoch:
+ return float(current_step) / float(max(1.0, self.warmup_epoch))
+ return 1.0
+
+
+class LinearLR(object):
+
+ def __init__(self,
+ epochs,
+ step_each_epoch,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs):
+ super(LinearLR, self).__init__()
+ self.epochs = epochs * step_each_epoch
+ self.last_epoch = last_epoch
+ self.warmup_epoch = warmup_epoch * step_each_epoch
+
+ def __call__(self, optimizer):
+ return lr_scheduler.LambdaLR(optimizer, self.lambda_func,
+ self.last_epoch)
+
+ def lambda_func(self, current_step):
+ if current_step < self.warmup_epoch:
+ return float(current_step) / float(max(1, self.warmup_epoch))
+ return max(
+ 0.0,
+ float(self.epochs - current_step) /
+ float(max(1, self.epochs - self.warmup_epoch)),
+ )
+
+
+class CosineAnnealingLR(object):
+
+ def __init__(self,
+ epochs,
+ step_each_epoch,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs):
+ super(CosineAnnealingLR, self).__init__()
+ self.epochs = epochs * step_each_epoch
+ self.last_epoch = last_epoch
+ self.warmup_epoch = warmup_epoch * step_each_epoch
+
+ def __call__(self, optimizer):
+ return lr_scheduler.LambdaLR(optimizer, self.lambda_func,
+ self.last_epoch)
+
+ def lambda_func(self, current_step, num_cycles=0.5):
+ if current_step < self.warmup_epoch:
+ return float(current_step) / float(max(1, self.warmup_epoch))
+ progress = float(current_step - self.warmup_epoch) / float(
+ max(1, self.epochs - self.warmup_epoch))
+ return max(
+ 0.0, 0.5 *
+ (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+
+class OneCycleLR(object):
+
+ def __init__(self,
+ epochs,
+ step_each_epoch,
+ last_epoch=-1,
+ lr=0.00148,
+ warmup_epoch=1.0,
+ cycle_momentum=True,
+ **kwargs):
+ super(OneCycleLR, self).__init__()
+ self.epochs = epochs
+ self.last_epoch = last_epoch
+ self.step_each_epoch = step_each_epoch
+ self.lr = lr
+ self.pct_start = warmup_epoch / epochs
+ self.cycle_momentum = cycle_momentum
+
+ def __call__(self, optimizer):
+ return lr_scheduler.OneCycleLR(
+ optimizer,
+ max_lr=self.lr,
+ total_steps=self.epochs * self.step_each_epoch,
+ pct_start=self.pct_start,
+ cycle_momentum=self.cycle_momentum,
+ )
+
+
+class PolynomialLR(object):
+
+ def __init__(self,
+ step_each_epoch,
+ epochs,
+ lr_end=1e-7,
+ power=1.0,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs):
+ super(PolynomialLR, self).__init__()
+ self.lr_end = lr_end
+ self.power = power
+ self.epochs = epochs * step_each_epoch
+ self.warmup_epoch = warmup_epoch * step_each_epoch
+ self.last_epoch = last_epoch
+
+ def __call__(self, optimizer):
+ lr_lambda = partial(
+ self.lambda_func,
+ lr_init=optimizer.defaults['lr'],
+ )
+ return lr_scheduler.LambdaLR(optimizer, lr_lambda, self.last_epoch)
+
+ def lambda_func(self, current_step, lr_init):
+ if current_step < self.warmup_epoch:
+ return float(current_step) / float(max(1, self.warmup_epoch))
+ elif current_step > self.epochs:
+ return self.lr_end / lr_init # as LambdaLR multiplies by lr_init
+ else:
+ lr_range = lr_init - self.lr_end
+ decay_steps = self.epochs - self.warmup_epoch
+ pct_remaining = 1 - (current_step -
+ self.warmup_epoch) / decay_steps
+ decay = lr_range * pct_remaining**self.power + self.lr_end
+ return decay / lr_init # as LambdaLR multiplies by lr_init
+
+
+class CdistNetLR(object):
+
+ def __init__(self,
+ step_each_epoch,
+ lr=0.0442,
+ n_warmup_steps=10000,
+ step2_epoch=7,
+ last_epoch=-1,
+ **kwargs):
+ super(CdistNetLR, self).__init__()
+ self.last_epoch = last_epoch
+ self.step2_epoch = step2_epoch * step_each_epoch
+ self.n_current_steps = 0
+ self.n_warmup_steps = n_warmup_steps
+ self.init_lr = lr
+ self.step2_lr = 0.00001
+
+ def __call__(self, optimizer):
+ return lr_scheduler.LambdaLR(optimizer, self.lambda_func,
+ self.last_epoch)
+
+ def lambda_func(self, current_step):
+ if current_step < self.step2_epoch:
+ return np.min([
+ np.power(current_step, -0.5),
+ np.power(self.n_warmup_steps, -1.5) * current_step,
+ ])
+ return self.step2_lr / self.init_lr
diff --git a/openrec/postprocess/__init__.py b/openrec/postprocess/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e12a57df1d7fd70bcf5e947cc58c70557e458898
--- /dev/null
+++ b/openrec/postprocess/__init__.py
@@ -0,0 +1,72 @@
+import copy
+
+__all__ = ['build_post_process']
+
+from .abinet_postprocess import ABINetLabelDecode
+from .ar_postprocess import ARLabelDecode
+from .ce_postprocess import CELabelDecode
+from .char_postprocess import CharLabelDecode
+from .cppd_postprocess import CPPDLabelDecode
+from .ctc_postprocess import CTCLabelDecode
+from .igtr_postprocess import IGTRLabelDecode
+from .lister_postprocess import LISTERLabelDecode
+from .mgp_postprocess import MPGLabelDecode
+from .nrtr_postprocess import NRTRLabelDecode
+from .smtr_postprocess import SMTRLabelDecode
+from .srn_postprocess import SRNLabelDecode
+from .visionlan_postprocess import VisionLANLabelDecode
+
+support_dict = [
+ 'CTCLabelDecode', 'CharLabelDecode', 'CELabelDecode', 'CPPDLabelDecode',
+ 'NRTRLabelDecode', 'ABINetLabelDecode', 'ARLabelDecode', 'IGTRLabelDecode',
+ 'VisionLANLabelDecode', 'SMTRLabelDecode', 'SRNLabelDecode',
+ 'LISTERLabelDecode', 'GTCLabelDecode', 'MPGLabelDecode'
+]
+
+
+def build_post_process(config, global_config=None):
+ config = copy.deepcopy(config)
+ module_name = config.pop('name')
+ if global_config is not None:
+ config.update(global_config)
+    assert module_name in support_dict, (
+        'post process only supports {}'.format(support_dict))
+ module_class = eval(module_name)(**config)
+ return module_class
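+
+
+# Illustrative config (values are hypothetical): any class listed in
+# support_dict can be built from a dict, e.g.
+#   post_op = build_post_process({'name': 'CTCLabelDecode',
+#                                 'character_dict_path': None,
+#                                 'use_space_char': False})
+#   texts = post_op(preds)                 # inference: list of (text, score)
+#   texts, labels = post_op(preds, batch)  # evaluation: also decodes labels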
+
+
+class GTCLabelDecode(object):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ gtc_label_decode=None,
+ character_dict_path=None,
+ use_space_char=True,
+ only_gtc=False,
+ with_ratio=False,
+ **kwargs):
+ gtc_label_decode['character_dict_path'] = character_dict_path
+ gtc_label_decode['use_space_char'] = use_space_char
+ self.gtc_label_decode = build_post_process(gtc_label_decode)
+ self.ctc_label_decode = CTCLabelDecode(
+ character_dict_path=character_dict_path,
+ use_space_char=use_space_char)
+ self.gtc_character = self.gtc_label_decode.character
+ self.ctc_character = self.ctc_label_decode.character
+ self.only_gtc = only_gtc
+ self.with_ratio = with_ratio
+
+ def get_character_num(self):
+ return [len(self.gtc_character), len(self.ctc_character)]
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+ if self.with_ratio:
+ batch = batch[:-1]
+ gtc = self.gtc_label_decode(preds['gtc_pred'],
+ batch[:-2] if batch is not None else None)
+ if self.only_gtc:
+ return gtc
+ ctc = self.ctc_label_decode(preds['ctc_pred'], [None] +
+ batch[-2:] if batch is not None else None)
+
+ return [gtc, ctc]
diff --git a/openrec/postprocess/abinet_postprocess.py b/openrec/postprocess/abinet_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cbae498ab31ce2175c3702b6641114f05596f6d
--- /dev/null
+++ b/openrec/postprocess/abinet_postprocess.py
@@ -0,0 +1,37 @@
+import torch
+
+from .nrtr_postprocess import NRTRLabelDecode
+
+
+class ABINetLabelDecode(NRTRLabelDecode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(ABINetLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+ if isinstance(preds, dict):
+ if len(preds['align']) > 0:
+ preds = preds['align'][-1].detach().cpu().numpy()
+ else:
+ preds = preds['vision'].detach().cpu().numpy()
+ elif isinstance(preds, torch.Tensor):
+ preds = preds.detach().cpu().numpy()
+ else:
+ preds = preds
+
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if batch is None:
+ return text
+ label = self.decode(batch[1].cpu().numpy())
+ return text, label
+
+ def add_special_char(self, dict_character):
+        dict_character = ['</s>'] + dict_character
+ return dict_character
diff --git a/openrec/postprocess/ar_postprocess.py b/openrec/postprocess/ar_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..0045a7c01b8e712dc4544f197aaa5931b200c7ca
--- /dev/null
+++ b/openrec/postprocess/ar_postprocess.py
@@ -0,0 +1,63 @@
+import numpy as np
+import torch
+
+from .ctc_postprocess import BaseRecLabelDecode
+
+
+class ARLabelDecode(BaseRecLabelDecode):
+ """Convert between text-label and text-index."""
+
+    BOS = '<s>'
+    EOS = '</s>'
+    PAD = '<pad>'
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=True,
+ **kwargs):
+ super(ARLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+
+ if isinstance(preds, list):
+ preds = preds[-1]
+ if isinstance(preds, torch.Tensor):
+ preds = preds.detach().cpu().numpy()
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob)
+ if batch is None:
+ return text
+ label = batch[1]
+ label = self.decode(label[:, 1:].detach().cpu().numpy())
+ return text, label
+
+ def add_special_char(self, dict_character):
+ dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
+ return dict_character
+
+ def decode(self, text_index, text_prob=None):
+ """convert text-index into text-label."""
+ result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ try:
+ char_idx = self.character[int(text_index[batch_idx][idx])]
+ except:
+ continue
+ if char_idx == self.EOS: # end
+ break
+ if char_idx == self.BOS or char_idx == self.PAD:
+ continue
+ char_list.append(char_idx)
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
diff --git a/openrec/postprocess/ce_postprocess.py b/openrec/postprocess/ce_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c4b5623f8d0485beefcfe71881b8efdc91956c7
--- /dev/null
+++ b/openrec/postprocess/ce_postprocess.py
@@ -0,0 +1,43 @@
+import torch
+
+from .ctc_postprocess import BaseRecLabelDecode
+
+
+class CELabelDecode(BaseRecLabelDecode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(CELabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ if isinstance(preds, tuple) or isinstance(preds, list):
+ preds = preds[-1]
+ if isinstance(preds, torch.Tensor):
+            preds = preds.detach().cpu().numpy()
+ preds_idx = preds.argmax(axis=-1)
+ preds_prob = preds.max(axis=-1)
+ text = self.decode(preds_idx, preds_prob)
+ if label is None:
+ return text
+ label = self.decode(label.flatten())
+ return text, label
+
+ def decode(self, text_index, text_prob=None):
+ """convert text-index into text-label."""
+ result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ text = self.character[text_index[batch_idx]]
+ if text_prob is not None:
+ conf_list = text_prob[batch_idx]
+ else:
+ conf_list = 1.0
+ result_list.append((text, conf_list))
+ return result_list
+
+ def add_special_char(self, dict_character):
+ return dict_character
diff --git a/openrec/postprocess/char_postprocess.py b/openrec/postprocess/char_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa569003a0d9e818e95f7698076716153eadb51e
--- /dev/null
+++ b/openrec/postprocess/char_postprocess.py
@@ -0,0 +1,108 @@
+import numpy as np
+import torch
+
+from .ctc_postprocess import BaseRecLabelDecode
+
+
+class CharLabelDecode(BaseRecLabelDecode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=True,
+ **kwargs):
+ super(CharLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ if len(preds) >= 4:
+ preds_id = preds[0]
+ preds_prob = preds[1]
+ char_preds = preds[2]
+ if isinstance(preds_id, torch.Tensor):
+ preds_id = preds_id.numpy()
+ if isinstance(preds_prob, torch.Tensor):
+ preds_prob = preds_prob.numpy()
+ if preds_id[0][0] == 2:
+ preds_idx = preds_id[:, 1:]
+ preds_prob = preds_prob[:, 1:]
+ # char_preds = char_preds[:, 1:]
+ else:
+ preds_idx = preds_id
+ char_preds = char_preds.numpy()
+ char_preds_idx = char_preds.argmax(-1) + 4
+ char_preds_prob = char_preds.max(-1)
+ text, text_box = self.decode(preds_idx, preds_prob, char_preds_idx,
+ char_preds_prob)
+ else:
+ preds_logit = preds[0].numpy()
+ char_preds = preds[1].numpy()
+ # if isinstance(preds, torch.Tensor):
+ # preds = preds.numpy()
+ preds_idx = preds_logit.argmax(axis=2)
+ preds_prob = preds_logit.max(axis=2)
+ char_preds_idx = char_preds.argmax(-1) + 4
+ char_preds_prob = char_preds.max(-1)
+ text, text_box = self.decode(preds_idx, preds_prob, char_preds_idx,
+ char_preds_prob)
+
+ if label is None:
+ return text, text_box
+ label = self.decode(label[:, 1:])
+ return text, text_box, label
+
+ def add_special_char(self, dict_character):
+        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
+ return dict_character
+
+ def decode(
+ self,
+ text_index,
+ text_prob=None,
+ char_text_index=None,
+ char_text_prob=None,
+ is_remove_duplicate=False,
+ ):
+ """convert text-index into text-label."""
+ result_list = []
+ box_result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ char_box_list = []
+ conf_box_list = []
+ for idx in range(len(text_index[batch_idx])):
+ try:
+ char_idx = self.character[int(text_index[batch_idx][idx])]
+ if char_text_index is not None:
+ char_box_idx = self.character[int(
+ char_text_index[batch_idx][idx])]
+ except:
+ continue
+                if char_idx == '</s>':  # end
+ break
+ char_list.append(char_idx)
+
+ if char_text_index is not None:
+ char_box_list.append(char_box_idx)
+
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+
+ if char_text_prob is not None:
+ conf_box_list.append(char_text_prob[batch_idx][idx])
+ else:
+ conf_box_list.append(1)
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list).tolist()))
+
+ if char_text_index is not None:
+ text_box = ''.join(char_box_list)
+ box_result_list.append(
+ (text_box, np.mean(conf_box_list).tolist()))
+ if char_text_index is not None:
+ return result_list, box_result_list
+ return result_list
diff --git a/openrec/postprocess/cppd_postprocess.py b/openrec/postprocess/cppd_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e044c3bd8f1603192190dbe69d4182de55be209
--- /dev/null
+++ b/openrec/postprocess/cppd_postprocess.py
@@ -0,0 +1,42 @@
+import torch
+
+from .nrtr_postprocess import NRTRLabelDecode
+
+
+class CPPDLabelDecode(NRTRLabelDecode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(CPPDLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+
+ if isinstance(preds, tuple):
+ if isinstance(preds[-1], dict):
+ preds = preds[-1]['align'][-1].detach().cpu().numpy()
+ else:
+ preds = preds[-1].detach().cpu().numpy()
+ if isinstance(preds, list):
+ preds = preds[-1].detach().cpu().numpy()
+ if isinstance(preds, torch.Tensor):
+ preds = preds.detach().cpu().numpy()
+ elif isinstance(preds, dict):
+ preds = preds['align'][-1].detach().cpu().numpy()
+ else:
+ preds = preds
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if batch is None:
+ return text
+ label = batch[1]
+ label = self.decode(label.detach().cpu().numpy())
+ return text, label
+
+ def add_special_char(self, dict_character):
+        dict_character = ['</s>'] + dict_character
+ return dict_character
diff --git a/openrec/postprocess/ctc_postprocess.py b/openrec/postprocess/ctc_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd160130230e46fa669f506f992fc9235bc5412b
--- /dev/null
+++ b/openrec/postprocess/ctc_postprocess.py
@@ -0,0 +1,119 @@
+import re
+
+import numpy as np
+import torch
+
+
+class BaseRecLabelDecode(object):
+ """Convert between text-label and text-index."""
+
+ def __init__(self, character_dict_path=None, use_space_char=False):
+ self.beg_str = 'sos'
+ self.end_str = 'eos'
+ self.reverse = False
+ self.character_str = []
+
+ if character_dict_path is None:
+ self.character_str = '0123456789abcdefghijklmnopqrstuvwxyz'
+ dict_character = list(self.character_str)
+ else:
+ with open(character_dict_path, 'rb') as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip('\n').strip('\r\n')
+ self.character_str.append(line)
+ if use_space_char:
+ self.character_str.append(' ')
+ dict_character = list(self.character_str)
+ if 'arabic' in character_dict_path:
+ self.reverse = True
+
+ dict_character = self.add_special_char(dict_character)
+ self.dict = {}
+ for i, char in enumerate(dict_character):
+ self.dict[char] = i
+ self.character = dict_character
+
+ def pred_reverse(self, pred):
+ pred_re = []
+ c_current = ''
+ for c in pred:
+ if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
+ if c_current != '':
+ pred_re.append(c_current)
+ pred_re.append(c)
+ c_current = ''
+ else:
+ c_current += c
+ if c_current != '':
+ pred_re.append(c_current)
+
+ return ''.join(pred_re[::-1])
+
+ def add_special_char(self, dict_character):
+ return dict_character
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """convert text-index into text-label."""
+ result_list = []
+ ignored_tokens = self.get_ignored_tokens()
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ selection = np.ones(len(text_index[batch_idx]), dtype=bool)
+ if is_remove_duplicate:
+ selection[1:] = text_index[batch_idx][1:] != text_index[
+ batch_idx][:-1]
+ for ignored_token in ignored_tokens:
+ selection &= text_index[batch_idx] != ignored_token
+
+ char_list = [
+ self.character[text_id]
+ for text_id in text_index[batch_idx][selection]
+ ]
+ if text_prob is not None:
+ conf_list = text_prob[batch_idx][selection]
+ else:
+ conf_list = [1] * len(selection)
+ if len(conf_list) == 0:
+ conf_list = [0]
+
+ text = ''.join(char_list)
+
+ if self.reverse: # for arabic rec
+ text = self.pred_reverse(text)
+
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
+
+ def get_ignored_tokens(self):
+ return [0] # for ctc blank
+
+ def get_character_num(self):
+ return len(self.character)
+
+
+class CTCLabelDecode(BaseRecLabelDecode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(CTCLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+ # preds = preds['res']
+ if isinstance(preds, torch.Tensor):
+ preds = preds.detach().cpu().numpy()
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
+ if batch is None:
+ return text
+ label = self.decode(batch[1].cpu().numpy())
+ return text, label
+
+ def add_special_char(self, dict_character):
+ dict_character = ['blank'] + dict_character
+ return dict_character
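+
+
+# Worked example of the decoding rule above (greedy argmax, merge repeats, then
+# drop the blank): with character list ['blank', 'a', 'b'], the index sequence
+# [1, 1, 0, 2, 2] first collapses to [1, 0, 2] and, after removing the blank
+# (index 0), decodes to 'ab'.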
diff --git a/openrec/postprocess/igtr_postprocess.py b/openrec/postprocess/igtr_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fc12ca0b8a4b068bb37153face693167a323daa
--- /dev/null
+++ b/openrec/postprocess/igtr_postprocess.py
@@ -0,0 +1,100 @@
+import numpy as np
+import torch
+
+from .nrtr_postprocess import NRTRLabelDecode
+
+
+class IGTRLabelDecode(NRTRLabelDecode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(IGTRLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+
+ if isinstance(preds, list):
+ if isinstance(preds[0], dict):
+ preds = preds[-1].detach().cpu().numpy()
+ if isinstance(preds, torch.Tensor):
+ preds = preds.detach().cpu().numpy()
+ elif isinstance(preds, dict):
+ preds = preds['align'][-1].detach().cpu().numpy()
+ else:
+ preds = preds
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx,
+ preds_prob,
+ is_remove_duplicate=False)
+ else:
+ preds_idx = preds[0].detach().cpu().numpy()
+ preds_prob = preds[1].detach().cpu().numpy()
+ text = self.decode(preds_idx,
+ preds_prob,
+ is_remove_duplicate=False)
+ else:
+ if isinstance(preds, torch.Tensor):
+ preds = preds.detach().cpu().numpy()
+ elif isinstance(preds, dict):
+ preds = preds['align'][-1].detach().cpu().numpy()
+ else:
+ preds = preds
+ preds_idx = preds.argmax(axis=2)
+ preds_idx_top5 = preds.argsort(axis=2)[:, :, -5:]
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx,
+ preds_prob,
+ is_remove_duplicate=False,
+ idx_top5=preds_idx_top5)
+ if batch is None:
+ return text
+ label = batch[1]
+ label = self.decode(label.detach().cpu().numpy())
+ return text, label
+
+ def add_special_char(self, dict_character):
+        # EOS token first, padding/unknown markers appended last
+        # (the exact marker strings here are assumed).
+        dict_character = ['</s>'] + dict_character + ['<pad>', '<unk>']
+ return dict_character
+
+ def decode(self,
+ text_index,
+ text_prob=None,
+ is_remove_duplicate=False,
+ idx_top5=None):
+ """convert text-index into text-label."""
+ result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ char_list_top5 = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ char_idx_top5 = []
+ try:
+ char_idx = self.character[int(text_index[batch_idx][idx])]
+ if idx_top5 is not None:
+ for top5_i in idx_top5[batch_idx][idx]:
+ char_idx_top5.append(self.character[top5_i])
+ except:
+ continue
+                if char_idx == '</s>':  # end
+                    break
+                if char_idx == '<pad>' or char_idx == '<unk>':
+                    continue
+ char_list.append(char_idx)
+ char_list_top5.append(char_idx_top5)
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = ''.join(char_list)
+ if idx_top5 is not None:
+ result_list.append(
+ (text, [np.mean(conf_list).tolist(), char_list_top5]))
+ else:
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
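+
+# Note: when idx_top5 is passed to decode(), each returned entry is
+# (text, [mean_confidence, per_position_top5_characters]); otherwise it is the
+# usual (text, mean_confidence) pair.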
diff --git a/openrec/postprocess/lister_postprocess.py b/openrec/postprocess/lister_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b4750ab94f9eb2931482e74de4f4986ec5bc0b0
--- /dev/null
+++ b/openrec/postprocess/lister_postprocess.py
@@ -0,0 +1,59 @@
+import numpy as np
+import torch
+
+from openrec.postprocess.ctc_postprocess import BaseRecLabelDecode
+
+
+class LISTERLabelDecode(BaseRecLabelDecode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=True,
+ **kwargs):
+ super(LISTERLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+
+ preds = preds[1]['logits']
+ if isinstance(preds, torch.Tensor):
+ preds = preds.detach().cpu().numpy()
+ preds_idx = preds.argmax(axis=2)
+ # preds_idx_top5 = preds.argsort(axis=2)[:, :, -5:]
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if batch is None:
+ return text
+ label = batch[1]
+ label = self.decode(label.detach().cpu().numpy())
+ return text, label
+
+ def add_special_char(self, dict_character):
+        # EOS marker first, padding marker appended last
+        # (the exact marker strings here are assumed).
+        dict_character = ['<eos>'] + dict_character + ['<pad>']
+ return dict_character
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """convert text-index into text-label."""
+ result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ try:
+ char_idx = self.character[int(text_index[batch_idx][idx])]
+ except:
+ continue
+                if char_idx == '<eos>':  # end
+                    break
+                if char_idx == '<pad>':
+                    continue
+ char_list.append(char_idx)
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
diff --git a/openrec/postprocess/mgp_postprocess.py b/openrec/postprocess/mgp_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..8677fc9bf4d32cf098610ba79828b8c533ddc6a0
--- /dev/null
+++ b/openrec/postprocess/mgp_postprocess.py
@@ -0,0 +1,143 @@
+from .ctc_postprocess import BaseRecLabelDecode
+
+
+class MPGLabelDecode(BaseRecLabelDecode):
+ """Convert between text-label and text-index."""
+ SPACE = '[s]'
+ GO = '[GO]'
+ list_token = [GO, SPACE]
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=False,
+ only_char=False,
+ **kwargs):
+ super(MPGLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+ self.only_char = only_char
+ self.EOS = '[s]'
+ self.PAD = '[GO]'
+ if not only_char:
+ # transformers==4.2.1
+ from transformers import BertTokenizer, GPT2Tokenizer
+ self.bpe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ self.wp_tokenizer = BertTokenizer.from_pretrained(
+ 'bert-base-uncased')
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+
+ if isinstance(preds, list):
+ char_preds = preds[0].detach().cpu().numpy()
+ else:
+ char_preds = preds.detach().cpu().numpy()
+
+ preds_idx = char_preds.argmax(axis=2)
+ preds_prob = char_preds.max(axis=2)
+ char_text = self.char_decode(preds_idx[:, 1:], preds_prob[:, 1:])
+ if batch is None:
+ return char_text
+ label = batch[1]
+ label = self.char_decode(label[:, 1:].detach().cpu().numpy())
+ if self.only_char:
+ return char_text, label
+ else:
+ bpe_preds = preds[1].detach().cpu().numpy()
+ wp_preds = preds[2]
+
+ bpe_preds_idx = bpe_preds.argmax(axis=2)
+ bpe_preds_prob = bpe_preds.max(axis=2)
+ bpe_text = self.bpe_decode(bpe_preds_idx[:, 1:],
+ bpe_preds_prob[:, 1:])
+
+ wp_preds = wp_preds.detach() #.cpu().numpy()
+ wp_preds_prob, wp_preds_idx = wp_preds.max(-1)
+ wp_text = self.wp_decode(wp_preds_idx[:, 1:], wp_preds_prob[:, 1:])
+
+ final_text = self.final_decode(char_text, bpe_text, wp_text)
+ return char_text, bpe_text, wp_text, final_text, label
+
+ def add_special_char(self, dict_character):
+ dict_character = self.list_token + dict_character
+ return dict_character
+
+ def final_decode(self, char_text, bpe_text, wp_text):
+ result_list = []
+ for (char_pred,
+ char_pred_conf), (bpe_pred,
+ bpe_pred_conf), (wp_pred, wp_pred_conf) in zip(
+ char_text, bpe_text, wp_text):
+ final_text = char_pred
+ final_prob = char_pred_conf
+ if bpe_pred_conf > final_prob:
+ final_text = bpe_pred
+ final_prob = bpe_pred_conf
+ if wp_pred_conf > final_prob:
+ final_text = wp_pred
+ final_prob = wp_pred_conf
+ result_list.append((final_text, final_prob))
+ return result_list
+
+ def char_decode(self, text_index, text_prob=None):
+ """ convert text-index into text-label. """
+ result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = 1.0
+ for idx in range(len(text_index[batch_idx])):
+ try:
+ char_idx = self.character[int(text_index[batch_idx][idx])]
+ except:
+ continue
+ if text_prob is not None:
+ conf_list *= text_prob[batch_idx][idx]
+
+ if char_idx == self.EOS: # end
+ break
+ if char_idx == self.PAD:
+ continue
+ char_list.append(char_idx)
+
+ text = ''.join(char_list)
+ result_list.append((text, conf_list))
+ return result_list
+
+ def bpe_decode(self, text_index, text_prob):
+ """ convert text-index into text-label. """
+ result_list = []
+ for text, probs in zip(text_index, text_prob):
+ text_decoded = []
+ conf_list = 1.0
+ for bpeindx, prob in zip(text, probs):
+ tokenstr = self.bpe_tokenizer.decode([bpeindx])
+ if tokenstr == '#':
+ break
+ text_decoded.append(tokenstr)
+ conf_list *= prob
+ text = ''.join(text_decoded)
+ result_list.append((text, conf_list))
+ return result_list
+
+ def wp_decode(self, text_index, text_prob=None):
+ """ convert text-index into text-label. """
+ result_list = []
+ for batch_idx, text in enumerate(text_index):
+ wp_pred = self.wp_tokenizer.decode(text)
+ wp_pred_EOS = wp_pred.find('[SEP]')
+ wp_pred = wp_pred[:wp_pred_EOS]
+ if text_prob is not None:
+ try:
+ # print(text.cpu().tolist())
+ wp_pred_EOS_index = text.cpu().tolist().index(102) + 1
+ except:
+ wp_pred_EOS_index = -1
+ wp_pred_max_prob = text_prob[batch_idx][:wp_pred_EOS_index]
+ try:
+ wp_confidence_score = wp_pred_max_prob.cumprod(
+ dim=0)[-1].cpu().numpy().sum()
+ except:
+ wp_confidence_score = 0.0
+ else:
+ wp_confidence_score = 1.0
+ result_list.append((wp_pred, wp_confidence_score))
+ return result_list
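+
+# final_decode keeps, per sample, whichever of the character-, BPE- and
+# WordPiece-level hypotheses has the highest sequence confidence (the product
+# of its per-token probabilities), e.g. ('hello', 0.91) beats ('he1lo', 0.40).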
diff --git a/openrec/postprocess/nrtr_postprocess.py b/openrec/postprocess/nrtr_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcee427c303f6079d1f7a5251b1b49b738c60aae
--- /dev/null
+++ b/openrec/postprocess/nrtr_postprocess.py
@@ -0,0 +1,75 @@
+import numpy as np
+import torch
+
+from .ctc_postprocess import BaseRecLabelDecode
+
+
+class NRTRLabelDecode(BaseRecLabelDecode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=True,
+ **kwargs):
+ super(NRTRLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+ preds = preds['res']
+ if len(preds) == 2:
+ preds_id = preds[0]
+ preds_prob = preds[1]
+ if isinstance(preds_id, torch.Tensor):
+ preds_id = preds_id.detach().cpu().numpy()
+ if isinstance(preds_prob, torch.Tensor):
+ preds_prob = preds_prob.detach().cpu().numpy()
+ if preds_id[0][0] == 2:
+ preds_idx = preds_id[:, 1:]
+ preds_prob = preds_prob[:, 1:]
+ else:
+ preds_idx = preds_id
+ text = self.decode(preds_idx,
+ preds_prob,
+ is_remove_duplicate=False)
+ if batch is None:
+ return text
+ label = self.decode(batch[1][:, 1:].cpu().numpy())
+ else:
+ if isinstance(preds, torch.Tensor):
+ preds = preds.detach().cpu().numpy()
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx,
+ preds_prob,
+ is_remove_duplicate=False)
+ if batch is None:
+ return text
+ label = self.decode(batch[1][:, 1:].cpu().numpy())
+ return text, label
+
+ def add_special_char(self, dict_character):
+        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
+ return dict_character
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """convert text-index into text-label."""
+ result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ try:
+ char_idx = self.character[int(text_index[batch_idx][idx])]
+ except:
+ continue
+                if char_idx == '</s>':  # end
+ break
+ char_list.append(char_idx)
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
diff --git a/openrec/postprocess/smtr_postprocess.py b/openrec/postprocess/smtr_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..2546943faa0b5b328e6dd81cd8b87a5986648dd6
--- /dev/null
+++ b/openrec/postprocess/smtr_postprocess.py
@@ -0,0 +1,73 @@
+import numpy as np
+import torch
+
+from .ctc_postprocess import BaseRecLabelDecode
+
+
+class SMTRLabelDecode(BaseRecLabelDecode):
+ """Convert between text-label and text-index."""
+
+    BOS = '<s>'
+    EOS = '</s>'
+    IN_F = '<INF>'  # ignore
+    IN_B = '<INB>'  # ignore
+    PAD = '<pad>'
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=True,
+ next_mode=True,
+ **kwargs):
+ super(SMTRLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+ self.next_mode = next_mode
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+ if isinstance(preds, list):
+ preds = preds[-1]
+ if isinstance(preds, torch.Tensor):
+ preds = preds.detach().cpu().numpy()
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if batch is None:
+ return text
+ label = batch[1]
+ label = self.decode(label[:, 1:].detach().cpu().numpy())
+ return text, label
+
+ def add_special_char(self, dict_character):
+ dict_character = [self.EOS] + dict_character + [
+ self.BOS, self.IN_F, self.IN_B, self.PAD
+ ]
+ self.num_character = len(dict_character)
+ return dict_character
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """convert text-index into text-label."""
+ result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ try:
+ char_idx = self.character[int(text_index[batch_idx][idx])]
+ except:
+ continue
+                if char_idx == self.EOS:  # end
+                    break
+                if char_idx == self.BOS or char_idx == self.PAD:
+                    continue
+ char_list.append(char_idx)
+
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ if self.next_mode or text_prob is None:
+ text = ''.join(char_list)
+ else:
+ text = ''.join(char_list[::-1])
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
diff --git a/openrec/postprocess/srn_postprocess.py b/openrec/postprocess/srn_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..24b6d66a67335bf13d07164ca4f97aa9d2bf053d
--- /dev/null
+++ b/openrec/postprocess/srn_postprocess.py
@@ -0,0 +1,80 @@
+import numpy as np
+import torch
+
+from .ctc_postprocess import BaseRecLabelDecode
+
+
+class SRNLabelDecode(BaseRecLabelDecode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(SRNLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+ self.max_len = 25
+
+ def add_special_char(self, dict_character):
+        # begin/end placeholders appended at the end of the charset; only their
+        # indices are used below, the strings themselves are assumed.
+        dict_character = dict_character + ['<BOS>', '<EOS>']
+ self.start_idx = len(dict_character) - 2
+ self.end_idx = len(dict_character) - 1
+ return dict_character
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """convert text-index into text-label."""
+ result_list = []
+ ignored_tokens = self.get_ignored_tokens()
+ # [B,25]
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ # print(f"text_index[{batch_idx}][{idx}]:{text_index[batch_idx][idx]}")
+ if text_index[batch_idx][idx] in ignored_tokens:
+ continue
+ if int(text_index[batch_idx][idx]) == int(self.end_idx):
+ if text_prob is None and idx == 0:
+ continue
+ else:
+ break
+ if is_remove_duplicate:
+ # only for predict
+ if idx > 0 and text_index[batch_idx][
+ idx - 1] == text_index[batch_idx][idx]:
+ continue
+ char_list.append(self.character[int(
+ text_index[batch_idx][idx])])
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+
+ if isinstance(preds, torch.Tensor):
+ preds = preds.reshape([-1, self.max_len, preds.shape[-1]])
+ preds = preds.detach().cpu().numpy()
+ else:
+ preds = preds[-1]
+ preds = preds.reshape([-1, self.max_len,
+ preds.shape[-1]]).detach().cpu().numpy()
+
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+
+ if batch is None:
+ return text
+
+ label = batch[1].cpu().numpy()
+ # print(f"label.shape:{label.shape}")
+ label = self.decode(label, is_remove_duplicate=False)
+ return text, label
+
+ def get_ignored_tokens(self):
+ return [self.start_idx, self.end_idx]
diff --git a/openrec/postprocess/visionlan_postprocess.py b/openrec/postprocess/visionlan_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0bfe82653e402d935526e4fa469ca6cf7509444
--- /dev/null
+++ b/openrec/postprocess/visionlan_postprocess.py
@@ -0,0 +1,81 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .ctc_postprocess import BaseRecLabelDecode
+
+
+class VisionLANLabelDecode(BaseRecLabelDecode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(VisionLANLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+ self.max_text_length = kwargs.get('max_text_length', 25)
+ self.nclass = len(self.character) + 1
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """convert text-index into text-label."""
+ result_list = []
+ ignored_tokens = self.get_ignored_tokens()
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ selection = np.ones(len(text_index[batch_idx]), dtype=bool)
+ if is_remove_duplicate:
+ selection[1:] = text_index[batch_idx][1:] != text_index[
+ batch_idx][:-1]
+ for ignored_token in ignored_tokens:
+ selection &= text_index[batch_idx] != ignored_token
+
+ char_list = [
+ self.character[text_id - 1]
+ for text_id in text_index[batch_idx][selection]
+ ]
+ if text_prob is not None:
+ conf_list = text_prob[batch_idx][selection]
+ else:
+ conf_list = [1] * len(selection)
+ if len(conf_list) == 0:
+ conf_list = [0]
+
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+ if len(preds) == 2: # eval mode
+ net_out, length = preds
+ if batch is not None:
+ label = batch[1]
+
+ else: # train mode
+ net_out = preds[0]
+ label, length = batch[1], batch[5]
+ net_out = torch.cat([t[:l] for t, l in zip(net_out, length)],
+ dim=0)
+ text = []
+ if not isinstance(net_out, torch.Tensor):
+ net_out = torch.tensor(net_out, dtype=torch.float32)
+ net_out = F.softmax(net_out, dim=1)
+ for i in range(0, length.shape[0]):
+ preds_idx = (net_out[int(length[:i].sum()):int(length[:i].sum() +
+ length[i])].topk(1)
+ [1][:, 0].tolist())
+ preds_text = ''.join([
+ self.character[idx - 1]
+ if idx > 0 and idx <= len(self.character) else ''
+ for idx in preds_idx
+ ])
+ preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum() +
+ length[i])].topk(
+ 1)[0][:, 0]
+ preds_prob = torch.exp(
+ torch.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
+ text.append((preds_text, float(preds_prob)))
+ if batch is None:
+ return text
+ label = self.decode(label.detach().cpu().numpy())
+ return text, label
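+
+# The per-string confidence above is the length-normalised geometric mean of
+# the top-1 character probabilities: exp(sum(log p_i) / (n + 1e-6)).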
diff --git a/openrec/preprocess/__init__.py b/openrec/preprocess/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0558b30e2f77477adce49cccf2e664f9c0d6d106
--- /dev/null
+++ b/openrec/preprocess/__init__.py
@@ -0,0 +1,173 @@
+import io
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from .abinet_label_encode import ABINetLabelEncode
+from .ar_label_encode import ARLabelEncode
+from .ce_label_encode import CELabelEncode
+from .char_label_encode import CharLabelEncode
+from .cppd_label_encode import CPPDLabelEncode
+from .ctc_label_encode import CTCLabelEncode
+from .ep_label_encode import EPLabelEncode
+from .igtr_label_encode import IGTRLabelEncode
+from .mgp_label_encode import MGPLabelEncode
+from .rec_aug import ABINetAug
+from .rec_aug import BaseDataAugmentation as BDA
+from .rec_aug import PARSeqAug, PARSeqAugPIL, SVTRAug
+from .resize import (ABINetResize, CDistNetResize, LongResize, RecTVResize,
+ RobustScannerRecResizeImg, SliceResize, SliceTVResize,
+ SRNRecResizeImg, SVTRResize, VisionLANResize,
+ RecDynamicResize)
+from .smtr_label_encode import SMTRLabelEncode
+from .srn_label_encode import SRNLabelEncode
+from .visionlan_label_encode import VisionLANLabelEncode
+from .cam_label_encode import CAMLabelEncode
+
+
+class KeepKeys(object):
+
+ def __init__(self, keep_keys, **kwargs):
+ self.keep_keys = keep_keys
+
+ def __call__(self, data):
+ data_list = []
+ for key in self.keep_keys:
+ data_list.append(data[key])
+ return data_list
+
+
+def transform(data, ops=None):
+ """transform."""
+ if ops is None:
+ ops = []
+ for op in ops:
+ data = op(data)
+ if data is None:
+ return None
+ return data
+
+
+class Fasttext(object):
+
+ def __init__(self, path='None', **kwargs):
+ # pip install fasttext==0.9.1
+ import fasttext
+
+ self.fast_model = fasttext.load_model(path)
+
+ def __call__(self, data):
+ label = data['label']
+ fast_label = self.fast_model[label]
+ data['fast_label'] = fast_label
+ return data
+
+
+class DecodeImage(object):
+ """decode image."""
+
+ def __init__(self,
+ img_mode='RGB',
+ channel_first=False,
+ ignore_orientation=False,
+ **kwargs):
+ self.img_mode = img_mode
+ self.channel_first = channel_first
+ self.ignore_orientation = ignore_orientation
+
+ def __call__(self, data):
+ img = data['image']
+
+ assert type(img) is bytes and len(
+ img) > 0, "invalid input 'img' in DecodeImage"
+ img = np.frombuffer(img, dtype='uint8')
+ if self.ignore_orientation:
+ img = cv2.imdecode(
+ img, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)
+ else:
+ img = cv2.imdecode(img, 1)
+ if img is None:
+ return None
+ if self.img_mode == 'GRAY':
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ elif self.img_mode == 'RGB':
+ assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
+ img.shape)
+ img = img[:, :, ::-1]
+
+ if self.channel_first:
+ img = img.transpose((2, 0, 1))
+
+ data['image'] = img
+ return data
+
+
+class DecodeImagePIL(object):
+ """decode image."""
+
+ def __init__(self, img_mode='RGB', **kwargs):
+ self.img_mode = img_mode
+
+ def __call__(self, data):
+ img = data['image']
+ assert type(img) is bytes and len(
+ img) > 0, "invalid input 'img' in DecodeImage"
+ img = data['image']
+ buf = io.BytesIO(img)
+ img = Image.open(buf).convert('RGB')
+ if self.img_mode == 'Gray':
+ img = img.convert('L')
+ elif self.img_mode == 'BGR':
+            img = np.array(img)[:, :, ::-1]  # to numpy, reverse the channel dim (RGB -> BGR)
+ img = Image.fromarray(np.uint8(img))
+ data['image'] = img
+ return data
+
+
+def create_operators(op_param_list, global_config=None):
+ """create operators based on the config.
+
+ Args:
+ params(list): a dict list, used to create some operators
+ """
+ assert isinstance(op_param_list, list), 'operator config should be a list'
+ ops = []
+ for operator in op_param_list:
+ assert isinstance(operator,
+ dict) and len(operator) == 1, 'yaml format error'
+ op_name = list(operator)[0]
+ param = {} if operator[op_name] is None else operator[op_name]
+ if global_config is not None:
+ param.update(global_config)
+ op = eval(op_name)(**param)
+ ops.append(op)
+ return ops
+
+
+class GTCLabelEncode():
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ gtc_label_encode,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ self.gtc_label_encode = eval(gtc_label_encode['name'])(
+ max_text_length=max_text_length,
+ character_dict_path=character_dict_path,
+ use_space_char=use_space_char,
+ **gtc_label_encode)
+ self.ctc_label_encode = CTCLabelEncode(max_text_length,
+ character_dict_path,
+ use_space_char)
+
+ def __call__(self, data):
+ data_ctc = self.ctc_label_encode({'label': data['label']})
+ data = self.gtc_label_encode(data)
+ if data_ctc is None or data is None:
+ return None
+ data['ctc_label'] = data_ctc['label']
+ data['ctc_length'] = data_ctc['length']
+ return data
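+
+
+# Illustrative pipeline (op parameters are hypothetical): a YAML transforms
+# list maps directly onto create_operators, e.g.
+#   ops = create_operators([
+#       {'DecodeImage': {'img_mode': 'RGB'}},
+#       {'CTCLabelEncode': None},
+#       {'KeepKeys': {'keep_keys': ['image', 'label', 'length']}},
+#   ], global_config={'character_dict_path': None,
+#                     'use_space_char': False,
+#                     'max_text_length': 25})
+#   sample = transform({'image': img_bytes, 'label': 'text'}, ops)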
diff --git a/openrec/preprocess/abinet_aug.py b/openrec/preprocess/abinet_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..07ceb9aa3ac2add475e02329c4b5e02454aa6730
--- /dev/null
+++ b/openrec/preprocess/abinet_aug.py
@@ -0,0 +1,473 @@
+"""This code is refer from:
+
+https://github.com/FangShancheng/ABINet/blob/main/transforms.py
+"""
+import math
+import numbers
+import random
+
+import cv2
+import numpy as np
+from PIL import Image
+from torchvision.transforms import ColorJitter, Compose
+
+
+def sample_asym(magnitude, size=None):
+ return np.random.beta(1, 4, size) * magnitude
+
+
+def sample_sym(magnitude, size=None):
+ return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
+
+
+def sample_uniform(low, high, size=None):
+ return np.random.uniform(low, high, size=size)
+
+
+def get_interpolation(type='random'):
+ if type == 'random':
+ choice = [
+ cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC,
+ cv2.INTER_AREA
+ ]
+ interpolation = choice[random.randint(0, len(choice) - 1)]
+ elif type == 'nearest':
+ interpolation = cv2.INTER_NEAREST
+ elif type == 'linear':
+ interpolation = cv2.INTER_LINEAR
+ elif type == 'cubic':
+ interpolation = cv2.INTER_CUBIC
+ elif type == 'area':
+ interpolation = cv2.INTER_AREA
+ else:
+ raise TypeError(
+            'Only nearest, linear, cubic and area interpolation types are supported!'
+ )
+ return interpolation
+
+
+class CVRandomRotation(object):
+
+ def __init__(self, degrees=15):
+ assert isinstance(degrees,
+ numbers.Number), 'degree should be a single number.'
+ assert degrees >= 0, 'degree must be positive.'
+ self.degrees = degrees
+
+ @staticmethod
+ def get_params(degrees):
+ return sample_sym(degrees)
+
+ def __call__(self, img):
+ angle = self.get_params(self.degrees)
+ src_h, src_w = img.shape[:2]
+ M = cv2.getRotationMatrix2D(center=(src_w / 2, src_h / 2),
+ angle=angle,
+ scale=1.0)
+ abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1])
+ dst_w = int(src_h * abs_sin + src_w * abs_cos)
+ dst_h = int(src_h * abs_cos + src_w * abs_sin)
+ M[0, 2] += (dst_w - src_w) / 2
+ M[1, 2] += (dst_h - src_h) / 2
+
+ flags = get_interpolation()
+ return cv2.warpAffine(img,
+ M, (dst_w, dst_h),
+ flags=flags,
+ borderMode=cv2.BORDER_REPLICATE)
+
+
+class CVRandomAffine(object):
+
+ def __init__(self, degrees, translate=None, scale=None, shear=None):
+ assert isinstance(degrees,
+ numbers.Number), 'degree should be a single number.'
+ assert degrees >= 0, 'degree must be positive.'
+ self.degrees = degrees
+
+ if translate is not None:
+ assert (
+ isinstance(translate, (tuple, list)) and len(translate) == 2
+ ), 'translate should be a list or tuple and it must be of length 2.'
+ for t in translate:
+ if not (0.0 <= t <= 1.0):
+ raise ValueError(
+ 'translation values should be between 0 and 1')
+ self.translate = translate
+
+ if scale is not None:
+ assert (
+ isinstance(scale, (tuple, list)) and len(scale) == 2
+ ), 'scale should be a list or tuple and it must be of length 2.'
+ for s in scale:
+ if s <= 0:
+ raise ValueError('scale values should be positive')
+ self.scale = scale
+
+ if shear is not None:
+ if isinstance(shear, numbers.Number):
+ if shear < 0:
+ raise ValueError(
+ 'If shear is a single number, it must be positive.')
+ self.shear = [shear]
+ else:
+ assert isinstance(shear, (tuple, list)) and (
+ len(shear) == 2
+ ), 'shear should be a list or tuple and it must be of length 2.'
+ self.shear = shear
+ else:
+ self.shear = shear
+
+ def _get_inverse_affine_matrix(self, center, angle, translate, scale,
+ shear):
+ # https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717
+ from numpy import cos, sin, tan
+
+ if isinstance(shear, numbers.Number):
+ shear = [shear, 0]
+
+        if not (isinstance(shear, (tuple, list)) and len(shear) == 2):
+ raise ValueError(
+ 'Shear should be a single value or a tuple/list containing ' +
+ 'two values. Got {}'.format(shear))
+
+ rot = math.radians(angle)
+ sx, sy = [math.radians(s) for s in shear]
+
+ cx, cy = center
+ tx, ty = translate
+
+ # RSS without scaling
+ a = cos(rot - sy) / cos(sy)
+ b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
+ c = sin(rot - sy) / cos(sy)
+ d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
+
+ # Inverted rotation matrix with scale and shear
+ # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
+ M = [d, -b, 0, -c, a, 0]
+ M = [x / scale for x in M]
+
+ # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
+ M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
+ M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
+
+ # Apply center translation: C * RSS^-1 * C^-1 * T^-1
+ M[2] += cx
+ M[5] += cy
+ return M
+
+ @staticmethod
+ def get_params(degrees, translate, scale_ranges, shears, height):
+ angle = sample_sym(degrees)
+ if translate is not None:
+ max_dx = translate[0] * height
+ max_dy = translate[1] * height
+ translations = (np.round(sample_sym(max_dx)),
+ np.round(sample_sym(max_dy)))
+ else:
+ translations = (0, 0)
+
+ if scale_ranges is not None:
+ scale = sample_uniform(scale_ranges[0], scale_ranges[1])
+ else:
+ scale = 1.0
+
+ if shears is not None:
+ if len(shears) == 1:
+ shear = [sample_sym(shears[0]), 0.0]
+ elif len(shears) == 2:
+ shear = [sample_sym(shears[0]), sample_sym(shears[1])]
+ else:
+ shear = 0.0
+
+ return angle, translations, scale, shear
+
+ def __call__(self, img):
+ src_h, src_w = img.shape[:2]
+ angle, translate, scale, shear = self.get_params(
+ self.degrees, self.translate, self.scale, self.shear, src_h)
+
+ M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle,
+ (0, 0), scale, shear)
+ M = np.array(M).reshape(2, 3)
+
+ startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1),
+ (0, src_h - 1)]
+ project = lambda x, y, a, b, c: int(a * x + b * y + c)
+ endpoints = [(project(x, y, *M[0]), project(x, y, *M[1]))
+ for x, y in startpoints]
+
+ rect = cv2.minAreaRect(np.array(endpoints))
+ bbox = cv2.boxPoints(rect).astype(dtype=np.int32)
+ max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
+ min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
+
+ dst_w = int(max_x - min_x)
+ dst_h = int(max_y - min_y)
+ M[0, 2] += (dst_w - src_w) / 2
+ M[1, 2] += (dst_h - src_h) / 2
+
+ # add translate
+ dst_w += int(abs(translate[0]))
+ dst_h += int(abs(translate[1]))
+ if translate[0] < 0:
+ M[0, 2] += abs(translate[0])
+ if translate[1] < 0:
+ M[1, 2] += abs(translate[1])
+
+ flags = get_interpolation()
+ return cv2.warpAffine(img,
+ M, (dst_w, dst_h),
+ flags=flags,
+ borderMode=cv2.BORDER_REPLICATE)
+
+
+class CVRandomPerspective(object):
+
+ def __init__(self, distortion=0.5):
+ self.distortion = distortion
+
+ def get_params(self, width, height, distortion):
+ offset_h = sample_asym(distortion * height / 2,
+ size=4).astype(dtype=np.int32)
+ offset_w = sample_asym(distortion * width / 2,
+ size=4).astype(dtype=np.int32)
+ topleft = (offset_w[0], offset_h[0])
+ topright = (width - 1 - offset_w[1], offset_h[1])
+ botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
+ botleft = (offset_w[3], height - 1 - offset_h[3])
+
+ startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1),
+ (0, height - 1)]
+ endpoints = [topleft, topright, botright, botleft]
+ return np.array(startpoints,
+ dtype=np.float32), np.array(endpoints,
+ dtype=np.float32)
+
+ def __call__(self, img):
+ height, width = img.shape[:2]
+ startpoints, endpoints = self.get_params(width, height,
+ self.distortion)
+ M = cv2.getPerspectiveTransform(startpoints, endpoints)
+
+ # TODO: more robust way to crop image
+ rect = cv2.minAreaRect(endpoints)
+ bbox = cv2.boxPoints(rect).astype(dtype=np.int32)
+ max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
+ min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
+ min_x, min_y = max(min_x, 0), max(min_y, 0)
+
+ flags = get_interpolation()
+ img = cv2.warpPerspective(img,
+ M, (max_x, max_y),
+ flags=flags,
+ borderMode=cv2.BORDER_REPLICATE)
+ img = img[min_y:, min_x:]
+ return img
+
+
+class CVRescale(object):
+
+ def __init__(self, factor=4, base_size=(128, 512)):
+ """Define image scales using gaussian pyramid and rescale image to
+ target scale.
+
+ Args:
+ factor: the decayed factor from base size, factor=4 keeps target scale by default.
+ base_size: base size the build the bottom layer of pyramid
+ """
+ if isinstance(factor, numbers.Number):
+ self.factor = round(sample_uniform(0, factor))
+ elif isinstance(factor, (tuple, list)) and len(factor) == 2:
+ self.factor = round(sample_uniform(factor[0], factor[1]))
+ else:
+            raise Exception('factor must be a number or a list/tuple of length 2')
+ # assert factor is valid
+ self.base_h, self.base_w = base_size[:2]
+
+ def __call__(self, img):
+ if self.factor == 0:
+ return img
+ src_h, src_w = img.shape[:2]
+ cur_w, cur_h = self.base_w, self.base_h
+ scale_img = cv2.resize(img, (cur_w, cur_h),
+ interpolation=get_interpolation())
+ for _ in range(self.factor):
+ scale_img = cv2.pyrDown(scale_img)
+ scale_img = cv2.resize(scale_img, (src_w, src_h),
+ interpolation=get_interpolation())
+ return scale_img
+
+
+class CVGaussianNoise(object):
+
+ def __init__(self, mean=0, var=20):
+ self.mean = mean
+ if isinstance(var, numbers.Number):
+ self.var = max(int(sample_asym(var)), 1)
+ elif isinstance(var, (tuple, list)) and len(var) == 2:
+ self.var = int(sample_uniform(var[0], var[1]))
+ else:
+            raise Exception('var must be a number or a list/tuple of length 2')
+
+ def __call__(self, img):
+ noise = np.random.normal(self.mean, self.var**0.5, img.shape)
+ img = np.clip(img + noise, 0, 255).astype(np.uint8)
+ return img
+
+
+class CVMotionBlur(object):
+
+ def __init__(self, degrees=12, angle=90):
+ if isinstance(degrees, numbers.Number):
+ self.degree = max(int(sample_asym(degrees)), 1)
+ elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
+ self.degree = int(sample_uniform(degrees[0], degrees[1]))
+ else:
+            raise Exception('degrees must be a number or a list/tuple of length 2')
+ self.angle = sample_uniform(-angle, angle)
+
+ def __call__(self, img):
+ M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2),
+ self.angle, 1)
+ motion_blur_kernel = np.zeros((self.degree, self.degree))
+ motion_blur_kernel[self.degree // 2, :] = 1
+ motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M,
+ (self.degree, self.degree))
+ motion_blur_kernel = motion_blur_kernel / self.degree
+ img = cv2.filter2D(img, -1, motion_blur_kernel)
+ img = np.clip(img, 0, 255).astype(np.uint8)
+ return img
+
+
+class CVGeometry(object):
+
+ def __init__(
+ self,
+ degrees=15,
+ translate=(0.3, 0.3),
+ scale=(0.5, 2.0),
+ shear=(45, 15),
+ distortion=0.5,
+ p=0.5,
+ ):
+ self.p = p
+ type_p = random.random()
+ if type_p < 0.33:
+ self.transforms = CVRandomRotation(degrees=degrees)
+ elif type_p < 0.66:
+ self.transforms = CVRandomAffine(degrees=degrees,
+ translate=translate,
+ scale=scale,
+ shear=shear)
+ else:
+ self.transforms = CVRandomPerspective(distortion=distortion)
+
+ def __call__(self, img):
+ if random.random() < self.p:
+ return self.transforms(img)
+ else:
+ return img
+
+
+class CVDeterioration(object):
+
+ def __init__(self, var, degrees, factor, p=0.5):
+ self.p = p
+ transforms = []
+ if var is not None:
+ transforms.append(CVGaussianNoise(var=var))
+ if degrees is not None:
+ transforms.append(CVMotionBlur(degrees=degrees))
+ if factor is not None:
+ transforms.append(CVRescale(factor=factor))
+
+ random.shuffle(transforms)
+ transforms = Compose(transforms)
+ self.transforms = transforms
+
+ def __call__(self, img):
+ if random.random() < self.p:
+ return self.transforms(img)
+ else:
+ return img
+
+
+class CVColorJitter(object):
+
+ def __init__(self,
+ brightness=0.5,
+ contrast=0.5,
+ saturation=0.5,
+ hue=0.1,
+ p=0.5):
+ self.p = p
+ self.transforms = ColorJitter(brightness=brightness,
+ contrast=contrast,
+ saturation=saturation,
+ hue=hue)
+
+ def __call__(self, img):
+ if random.random() < self.p:
+ return np.array(self.transforms(Image.fromarray(img)))
+ else:
+ return img
+
+
+class SVTRDeterioration(object):
+
+ def __init__(self, var, degrees, factor, p=0.5):
+ self.p = p
+ transforms = []
+ if var is not None:
+ transforms.append(CVGaussianNoise(var=var))
+ if degrees is not None:
+ transforms.append(CVMotionBlur(degrees=degrees))
+ if factor is not None:
+ transforms.append(CVRescale(factor=factor))
+ self.transforms = transforms
+
+ def __call__(self, img):
+ if random.random() < self.p:
+ random.shuffle(self.transforms)
+ transforms = Compose(self.transforms)
+ return transforms(img)
+ else:
+ return img
+
+
+class SVTRGeometry(object):
+
+ def __init__(
+ self,
+ aug_type=0,
+ degrees=15,
+ translate=(0.3, 0.3),
+ scale=(0.5, 2.0),
+ shear=(45, 15),
+ distortion=0.5,
+ p=0.5,
+ ):
+ self.aug_type = aug_type
+ self.p = p
+ self.transforms = []
+ self.transforms.append(CVRandomRotation(degrees=degrees))
+ self.transforms.append(
+ CVRandomAffine(degrees=degrees,
+ translate=translate,
+ scale=scale,
+ shear=shear))
+ self.transforms.append(CVRandomPerspective(distortion=distortion))
+
+ def __call__(self, img):
+ if random.random() < self.p:
+ if self.aug_type:
+ random.shuffle(self.transforms)
+ transforms = Compose(self.transforms[:random.randint(1, 3)])
+ img = transforms(img)
+ else:
+ img = self.transforms[random.randint(0, 2)](img)
+ return img
+ else:
+ return img
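A minimal usage sketch of the augmentation classes above (not part of the diff): it chains the SVTR-style geometry, deterioration and colour-jitter ops on a dummy text crop. Parameter values are illustrative only; torchvision and OpenCV are assumed to be installed.

import numpy as np

from openrec.preprocess.abinet_aug import (CVColorJitter, SVTRDeterioration,
                                           SVTRGeometry)

# Each op decides internally (with probability p) whether to fire.
aug_geo = SVTRGeometry(aug_type=1, degrees=45, translate=(0.0, 0.0),
                       scale=(0.5, 2.0), shear=(45, 15), distortion=0.5, p=0.5)
aug_det = SVTRDeterioration(var=20, degrees=6, factor=4, p=0.25)
aug_jit = CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25)

img = np.random.randint(0, 256, size=(32, 128, 3), dtype=np.uint8)  # H x W x C crop
img = aug_jit(aug_det(aug_geo(img)))
print(img.shape, img.dtype)  # geometry ops may change the spatial size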
diff --git a/openrec/preprocess/abinet_label_encode.py b/openrec/preprocess/abinet_label_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba0ba71de6ae32af803dcf8c6aea95f17265c3bb
--- /dev/null
+++ b/openrec/preprocess/abinet_label_encode.py
@@ -0,0 +1,36 @@
+import numpy as np
+
+from .ctc_label_encode import BaseRecLabelEncode
+
+
+class ABINetLabelEncode(BaseRecLabelEncode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ ignore_index=100,
+ **kwargs):
+
+ super(ABINetLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ use_space_char)
+ self.ignore_index = ignore_index
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) > self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
+ text.append(0)
+ text = text + [self.ignore_index] * (self.max_text_len + 1 - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def add_special_char(self, dict_character):
+        dict_character = ['</s>'] + dict_character  # end-of-sequence token at index 0
+ return dict_character
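An illustrative sketch of the resulting sample layout (not part of the diff), assuming the built-in lowercase-alphanumeric dictionary (character_dict_path=None) and the default ignore_index of 100.

from openrec.preprocess.abinet_label_encode import ABINetLabelEncode

encoder = ABINetLabelEncode(max_text_length=25)
sample = encoder({'label': 'abc'})
print(sample['length'])       # 3
print(sample['label'][:5])    # [11, 12, 13, 0, 100]: 'a', 'b', 'c', end token, padding
print(sample['label'].shape)  # (26,) == max_text_length + 1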
diff --git a/openrec/preprocess/ar_label_encode.py b/openrec/preprocess/ar_label_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..925489fa339e943aaced76d2bd629436a1ffaad2
--- /dev/null
+++ b/openrec/preprocess/ar_label_encode.py
@@ -0,0 +1,36 @@
+import numpy as np
+
+from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
+
+
+class ARLabelEncode(BaseRecLabelEncode):
+ """Convert between text-label and text-index."""
+
+    # Beginning-of-sequence, end-of-sequence and padding tokens; they must be
+    # distinct strings because they are used as dictionary keys.
+    BOS = '<BOS>'
+    EOS = '<EOS>'
+    PAD = '<PAD>'
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(ARLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ use_space_char)
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ data['length'] = np.array(len(text))
+ text = [self.dict[self.BOS]] + text + [self.dict[self.EOS]]
+ text = text + [self.dict[self.PAD]
+ ] * (self.max_text_len + 2 - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
+ return dict_character
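A similar sketch for ARLabelEncode (not part of the diff): the encoded label is wrapped in the BOS/EOS tokens and padded with PAD to max_text_length + 2, again assuming the built-in dictionary.

from openrec.preprocess.ar_label_encode import ARLabelEncode

encoder = ARLabelEncode(max_text_length=25)
sample = encoder({'label': 'ocr'})
print(sample['length'])       # 3 (length of the raw text, excluding BOS/EOS)
print(sample['label'].shape)  # (27,) == max_text_length + 2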
diff --git a/openrec/preprocess/auto_augment.py b/openrec/preprocess/auto_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..518d1f42846890c1296767e95aec738f448e8305
--- /dev/null
+++ b/openrec/preprocess/auto_augment.py
@@ -0,0 +1,1012 @@
+"""AutoAugment, RandAugment, AugMix, and 3-Augment for PyTorch.
+
+This code implements the searched ImageNet policies with various tweaks and improvements and
+does not include any of the search code.
+
+AA and RA Implementation adapted from:
+ https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
+
+AugMix adapted from:
+ https://github.com/google-research/augmix
+
+3-Augment based on: https://github.com/facebookresearch/deit/blob/main/README_revenge.md
+
+Papers:
+ AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
+ Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
+ RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
+ AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
+ 3-Augment: DeiT III: Revenge of the ViT - https://arxiv.org/abs/2204.07118
+
+Hacked together by / Copyright 2019, Ross Wightman
+"""
+import math
+import random
+import re
+from functools import partial
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+from PIL import Image, ImageChops, ImageEnhance, ImageFilter, ImageOps
+
+_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
+
+_FILL = (128, 128, 128)
+
+_LEVEL_DENOM = 10. # denominator for conversion from 'Mx' magnitude scale to fractional aug level for op arguments
+
+_HPARAMS_DEFAULT = dict(
+ translate_const=250,
+ img_mean=_FILL,
+)
+
+if hasattr(Image, 'Resampling'):
+ _RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR,
+ Image.Resampling.BICUBIC)
+ _DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC
+else:
+ _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+ _DEFAULT_INTERPOLATION = Image.BICUBIC
+
+
+def _interpolation(kwargs):
+ interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
+ if isinstance(interpolation, (list, tuple)):
+ return random.choice(interpolation)
+ return interpolation
+
+
+def _check_args_tf(kwargs):
+ if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
+ kwargs.pop('fillcolor')
+ kwargs['resample'] = _interpolation(kwargs)
+
+
+def shear_x(img, factor, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0),
+ **kwargs)
+
+
+def shear_y(img, factor, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0),
+ **kwargs)
+
+
+def translate_x_rel(img, pct, **kwargs):
+ pixels = pct * img.size[0]
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0),
+ **kwargs)
+
+
+def translate_y_rel(img, pct, **kwargs):
+ pixels = pct * img.size[1]
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels),
+ **kwargs)
+
+
+def translate_x_abs(img, pixels, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0),
+ **kwargs)
+
+
+def translate_y_abs(img, pixels, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels),
+ **kwargs)
+
+
+def rotate(img, degrees, **kwargs):
+ _check_args_tf(kwargs)
+ if _PIL_VER >= (5, 2):
+ return img.rotate(degrees, **kwargs)
+ if _PIL_VER >= (5, 0):
+ w, h = img.size
+ post_trans = (0, 0)
+ rotn_center = (w / 2.0, h / 2.0)
+ angle = -math.radians(degrees)
+ matrix = [
+ round(math.cos(angle), 15),
+ round(math.sin(angle), 15),
+ 0.0,
+ round(-math.sin(angle), 15),
+ round(math.cos(angle), 15),
+ 0.0,
+ ]
+
+ def transform(x, y, matrix):
+ (a, b, c, d, e, f) = matrix
+ return a * x + b * y + c, d * x + e * y + f
+
+ matrix[2], matrix[5] = transform(-rotn_center[0] - post_trans[0],
+ -rotn_center[1] - post_trans[1],
+ matrix)
+ matrix[2] += rotn_center[0]
+ matrix[5] += rotn_center[1]
+ return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
+ return img.rotate(degrees, resample=kwargs['resample'])
+
+
+def auto_contrast(img, **__):
+ return ImageOps.autocontrast(img)
+
+
+def invert(img, **__):
+ return ImageOps.invert(img)
+
+
+def equalize(img, **__):
+ return ImageOps.equalize(img)
+
+
+def solarize(img, thresh, **__):
+ return ImageOps.solarize(img, thresh)
+
+
+def solarize_add(img, add, thresh=128, **__):
+ lut = []
+ for i in range(256):
+ if i < thresh:
+ lut.append(min(255, i + add))
+ else:
+ lut.append(i)
+
+ if img.mode in ('L', 'RGB'):
+ if img.mode == 'RGB' and len(lut) == 256:
+ lut = lut + lut + lut
+ return img.point(lut)
+
+ return img
+
+
+def posterize(img, bits_to_keep, **__):
+ if bits_to_keep >= 8:
+ return img
+ return ImageOps.posterize(img, bits_to_keep)
+
+
+def contrast(img, factor, **__):
+ return ImageEnhance.Contrast(img).enhance(factor)
+
+
+def color(img, factor, **__):
+ return ImageEnhance.Color(img).enhance(factor)
+
+
+def brightness(img, factor, **__):
+ return ImageEnhance.Brightness(img).enhance(factor)
+
+
+def sharpness(img, factor, **__):
+ return ImageEnhance.Sharpness(img).enhance(factor)
+
+
+def gaussian_blur(img, factor, **__):
+ img = img.filter(ImageFilter.GaussianBlur(radius=factor))
+ return img
+
+
+def gaussian_blur_rand(img, factor, **__):
+ radius_min = 0.1
+ radius_max = 2.0
+ img = img.filter(
+ ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max *
+ factor)))
+ return img
+
+
+def desaturate(img, factor, **_):
+ factor = min(1., max(0., 1. - factor))
+ # enhance factor 0 = grayscale, 1.0 = no-change
+ return ImageEnhance.Color(img).enhance(factor)
+
+
+def _randomly_negate(v):
+ """With 50% prob, negate the value."""
+ return -v if random.random() > 0.5 else v
+
+
+def _rotate_level_to_arg(level, _hparams):
+ # range [-30, 30]
+ level = (level / _LEVEL_DENOM) * 30.
+ level = _randomly_negate(level)
+ return level,
+
+
+def _enhance_level_to_arg(level, _hparams):
+ # range [0.1, 1.9]
+ return (level / _LEVEL_DENOM) * 1.8 + 0.1,
+
+
+def _enhance_increasing_level_to_arg(level, _hparams):
+ # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
+ # range [0.1, 1.9] if level <= _LEVEL_DENOM
+ level = (level / _LEVEL_DENOM) * .9
+ level = max(0.1, 1.0 + _randomly_negate(level)) # keep it >= 0.1
+ return level,
+
+
+def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
+ level = (level / _LEVEL_DENOM)
+ level = min_val + (max_val - min_val) * level
+ if clamp:
+ level = max(min_val, min(max_val, level))
+ return level,
+
+
+def _shear_level_to_arg(level, _hparams):
+ # range [-0.3, 0.3]
+ level = (level / _LEVEL_DENOM) * 0.3
+ level = _randomly_negate(level)
+ return level,
+
+
+def _translate_abs_level_to_arg(level, hparams):
+ translate_const = hparams['translate_const']
+ level = (level / _LEVEL_DENOM) * float(translate_const)
+ level = _randomly_negate(level)
+ return level,
+
+
+def _translate_rel_level_to_arg(level, hparams):
+ # default range [-0.45, 0.45]
+ translate_pct = hparams.get('translate_pct', 0.45)
+ level = (level / _LEVEL_DENOM) * translate_pct
+ level = _randomly_negate(level)
+ return level,
+
+
+def _posterize_level_to_arg(level, _hparams):
+ # As per Tensorflow TPU EfficientNet impl
+ # range [0, 4], 'keep 0 up to 4 MSB of original image'
+ # intensity/severity of augmentation decreases with level
+ return int((level / _LEVEL_DENOM) * 4),
+
+
+def _posterize_increasing_level_to_arg(level, hparams):
+ # As per Tensorflow models research and UDA impl
+ # range [4, 0], 'keep 4 down to 0 MSB of original image',
+ # intensity/severity of augmentation increases with level
+ return 4 - _posterize_level_to_arg(level, hparams)[0],
+
+
+def _posterize_original_level_to_arg(level, _hparams):
+ # As per original AutoAugment paper description
+ # range [4, 8], 'keep 4 up to 8 MSB of image'
+ # intensity/severity of augmentation decreases with level
+ return int((level / _LEVEL_DENOM) * 4) + 4,
+
+
+def _solarize_level_to_arg(level, _hparams):
+ # range [0, 256]
+ # intensity/severity of augmentation decreases with level
+ return min(256, int((level / _LEVEL_DENOM) * 256)),
+
+
+def _solarize_increasing_level_to_arg(level, _hparams):
+ # range [0, 256]
+ # intensity/severity of augmentation increases with level
+ return 256 - _solarize_level_to_arg(level, _hparams)[0],
+
+
+def _solarize_add_level_to_arg(level, _hparams):
+ # range [0, 110]
+ return min(128, int((level / _LEVEL_DENOM) * 110)),
+
+
+LEVEL_TO_ARG = {
+ 'AutoContrast': None,
+ 'Equalize': None,
+ 'Invert': None,
+ 'Rotate': _rotate_level_to_arg,
+ # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
+ 'Posterize': _posterize_level_to_arg,
+ 'PosterizeIncreasing': _posterize_increasing_level_to_arg,
+ 'PosterizeOriginal': _posterize_original_level_to_arg,
+ 'Solarize': _solarize_level_to_arg,
+ 'SolarizeIncreasing': _solarize_increasing_level_to_arg,
+ 'SolarizeAdd': _solarize_add_level_to_arg,
+ 'Color': _enhance_level_to_arg,
+ 'ColorIncreasing': _enhance_increasing_level_to_arg,
+ 'Contrast': _enhance_level_to_arg,
+ 'ContrastIncreasing': _enhance_increasing_level_to_arg,
+ 'Brightness': _enhance_level_to_arg,
+ 'BrightnessIncreasing': _enhance_increasing_level_to_arg,
+ 'Sharpness': _enhance_level_to_arg,
+ 'SharpnessIncreasing': _enhance_increasing_level_to_arg,
+ 'ShearX': _shear_level_to_arg,
+ 'ShearY': _shear_level_to_arg,
+ 'TranslateX': _translate_abs_level_to_arg,
+ 'TranslateY': _translate_abs_level_to_arg,
+ 'TranslateXRel': _translate_rel_level_to_arg,
+ 'TranslateYRel': _translate_rel_level_to_arg,
+ 'Desaturate': partial(_minmax_level_to_arg, min_val=0.5, max_val=1.0),
+ 'GaussianBlur': partial(_minmax_level_to_arg, min_val=0.1, max_val=2.0),
+ 'GaussianBlurRand': _minmax_level_to_arg,
+}
+
+NAME_TO_OP = {
+ 'AutoContrast': auto_contrast,
+ 'Equalize': equalize,
+ 'Invert': invert,
+ 'Rotate': rotate,
+ 'Posterize': posterize,
+ 'PosterizeIncreasing': posterize,
+ 'PosterizeOriginal': posterize,
+ 'Solarize': solarize,
+ 'SolarizeIncreasing': solarize,
+ 'SolarizeAdd': solarize_add,
+ 'Color': color,
+ 'ColorIncreasing': color,
+ 'Contrast': contrast,
+ 'ContrastIncreasing': contrast,
+ 'Brightness': brightness,
+ 'BrightnessIncreasing': brightness,
+ 'Sharpness': sharpness,
+ 'SharpnessIncreasing': sharpness,
+ 'ShearX': shear_x,
+ 'ShearY': shear_y,
+ 'TranslateX': translate_x_abs,
+ 'TranslateY': translate_y_abs,
+ 'TranslateXRel': translate_x_rel,
+ 'TranslateYRel': translate_y_rel,
+ 'Desaturate': desaturate,
+ 'GaussianBlur': gaussian_blur,
+ 'GaussianBlurRand': gaussian_blur_rand,
+}
+
+
+class AugmentOp:
+
+ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+ hparams = hparams or _HPARAMS_DEFAULT
+ self.name = name
+ self.aug_fn = NAME_TO_OP[name]
+ self.level_fn = LEVEL_TO_ARG[name]
+ self.prob = prob
+ self.magnitude = magnitude
+ self.hparams = hparams.copy()
+ self.kwargs = dict(
+ fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
+ resample=hparams['interpolation']
+ if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
+ )
+
+ # If magnitude_std is > 0, we introduce some randomness
+ # in the usually fixed policy and sample magnitude from a normal distribution
+ # with mean `magnitude` and std-dev of `magnitude_std`.
+ # NOTE This is my own hack, being tested, not in papers or reference impls.
+ # If magnitude_std is inf, we sample magnitude from a uniform distribution
+ self.magnitude_std = self.hparams.get('magnitude_std', 0)
+ self.magnitude_max = self.hparams.get('magnitude_max', None)
+
+ def __call__(self, img):
+ if self.prob < 1.0 and random.random() > self.prob:
+ return img
+ magnitude = self.magnitude
+ if self.magnitude_std > 0:
+ # magnitude randomization enabled
+ if self.magnitude_std == float('inf'):
+ # inf == uniform sampling
+ magnitude = random.uniform(0, magnitude)
+ elif self.magnitude_std > 0:
+ magnitude = random.gauss(magnitude, self.magnitude_std)
+ # default upper_bound for the timm RA impl is _LEVEL_DENOM (10)
+ # setting magnitude_max overrides this to allow M > 10 (behaviour closer to Google TF RA impl)
+ upper_bound = self.magnitude_max or _LEVEL_DENOM
+ magnitude = max(0., min(magnitude, upper_bound))
+ level_args = self.level_fn(
+ magnitude, self.hparams) if self.level_fn is not None else tuple()
+ return self.aug_fn(img, *level_args, **self.kwargs)
+
+ def __repr__(self):
+ fs = self.__class__.__name__ + f'(name={self.name}, p={self.prob}'
+ fs += f', m={self.magnitude}, mstd={self.magnitude_std}'
+ if self.magnitude_max is not None:
+ fs += f', mmax={self.magnitude_max}'
+ fs += ')'
+ return fs
+
+
+def auto_augment_policy_v0(hparams):
+ # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
+ policy = [
+ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+ [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+ [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+ [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+ [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+ [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+ [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+ [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
+ [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+ [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+ [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+ [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+ [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)
+        ],  # This results in a black image with the TPU posterize variant
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+ [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+ ]
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+ return pc
+
+
+def auto_augment_policy_v0r(hparams):
+ # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
+ # in Google research implementation (number of bits discarded increases with magnitude)
+ policy = [
+ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+ [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+ [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+ [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+ [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+ [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+ [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+ [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
+ [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+ [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+ [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+ [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+ [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+ [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+ ]
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+ return pc
+
+
+def auto_augment_policy_original(hparams):
+ # ImageNet policy from https://arxiv.org/abs/1805.09501
+ policy = [
+ [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+ [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+ [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
+ [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
+ [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
+ [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
+ [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
+ [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
+ [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
+ [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+ [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
+ [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
+ [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
+ [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
+ [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+ ]
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+ return pc
+
+
+def auto_augment_policy_originalr(hparams):
+ # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
+ policy = [
+ [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+ [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+ [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
+ [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
+ [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
+ [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
+ [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
+ [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
+ [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
+ [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+ [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
+ [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
+ [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
+ [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
+ [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+ ]
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+ return pc
+
+
+def auto_augment_policy_3a(hparams):
+ policy = [
+ [('Solarize', 1.0, 5)], # 128 solarize threshold @ 5 magnitude
+ [('Desaturate', 1.0, 10)], # grayscale at 10 magnitude
+ [('GaussianBlurRand', 1.0, 10)],
+ ]
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+ return pc
+
+
+def auto_augment_policy(name='v0', hparams=None):
+ hparams = hparams or _HPARAMS_DEFAULT
+ if name == 'original':
+ return auto_augment_policy_original(hparams)
+ if name == 'originalr':
+ return auto_augment_policy_originalr(hparams)
+ if name == 'v0':
+ return auto_augment_policy_v0(hparams)
+ if name == 'v0r':
+ return auto_augment_policy_v0r(hparams)
+ if name == '3a':
+ return auto_augment_policy_3a(hparams)
+ assert False, f'Unknown AA policy {name}'
+
+
+class AutoAugment:
+
+ def __init__(self, policy):
+ self.policy = policy
+
+ def __call__(self, img):
+ sub_policy = random.choice(self.policy)
+ for op in sub_policy:
+ img = op(img)
+ return img
+
+ def __repr__(self):
+ fs = self.__class__.__name__ + '(policy='
+ for p in self.policy:
+ fs += '\n\t['
+ fs += ', '.join([str(op) for op in p])
+ fs += ']'
+ fs += ')'
+ return fs
+
+
+def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
+ """Create a AutoAugment transform.
+
+ Args:
+ config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
+ dashes ('-').
+            The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr', '3a').
+
+ The remaining sections:
+ 'mstd' - float std deviation of magnitude noise applied
+ Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
+
+ hparams: Other hparams (kwargs) for the AutoAugmentation scheme
+
+ Returns:
+ A PyTorch compatible Transform
+ """
+ config = config_str.split('-')
+ policy_name = config[0]
+ config = config[1:]
+ for c in config:
+ cs = re.split(r'(\d.*)', c)
+ if len(cs) < 2:
+ continue
+ key, val = cs[:2]
+ if key == 'mstd':
+ # noise param injected via hparams for now
+ hparams.setdefault('magnitude_std', float(val))
+ else:
+ assert False, 'Unknown AutoAugment config section'
+ aa_policy = auto_augment_policy(policy_name, hparams=hparams)
+ return AutoAugment(aa_policy)
+
+
+_RAND_TRANSFORMS = [
+ 'AutoContrast',
+ 'Equalize',
+ 'Invert',
+ 'Rotate',
+ 'Posterize',
+ 'Solarize',
+ 'SolarizeAdd',
+ 'Color',
+ 'Contrast',
+ 'Brightness',
+ 'Sharpness',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateXRel',
+ 'TranslateYRel',
+    # 'Cutout' # NOTE I've implemented this as random erasing separately
+]
+
+_RAND_INCREASING_TRANSFORMS = [
+ 'AutoContrast',
+ 'Equalize',
+ 'Invert',
+ 'Rotate',
+ 'PosterizeIncreasing',
+ 'SolarizeIncreasing',
+ 'SolarizeAdd',
+ 'ColorIncreasing',
+ 'ContrastIncreasing',
+ 'BrightnessIncreasing',
+ 'SharpnessIncreasing',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateXRel',
+ 'TranslateYRel',
+    # 'Cutout' # NOTE I've implemented this as random erasing separately
+]
+
+_RAND_3A = [
+ 'SolarizeIncreasing',
+ 'Desaturate',
+ 'GaussianBlur',
+]
+
+_RAND_WEIGHTED_3A = {
+ 'SolarizeIncreasing': 6,
+ 'Desaturate': 6,
+ 'GaussianBlur': 6,
+ 'Rotate': 3,
+ 'ShearX': 2,
+ 'ShearY': 2,
+ 'PosterizeIncreasing': 1,
+ 'AutoContrast': 1,
+ 'ColorIncreasing': 1,
+ 'SharpnessIncreasing': 1,
+ 'ContrastIncreasing': 1,
+ 'BrightnessIncreasing': 1,
+ 'Equalize': 1,
+ 'Invert': 1,
+}
+
+# These experimental weights are based loosely on the relative improvements mentioned in the paper.
+# They may not result in increased performance, but could likely be tuned to do so.
+_RAND_WEIGHTED_0 = {
+ 'Rotate': 3,
+ 'ShearX': 2,
+ 'ShearY': 2,
+ 'TranslateXRel': 1,
+ 'TranslateYRel': 1,
+ 'ColorIncreasing': .25,
+ 'SharpnessIncreasing': 0.25,
+ 'AutoContrast': 0.25,
+ 'SolarizeIncreasing': .05,
+ 'SolarizeAdd': .05,
+ 'ContrastIncreasing': .05,
+ 'BrightnessIncreasing': .05,
+ 'Equalize': .05,
+ 'PosterizeIncreasing': 0.05,
+ 'Invert': 0.05,
+}
+
+
+def _get_weighted_transforms(transforms: Dict):
+ transforms, probs = list(zip(*transforms.items()))
+ probs = np.array(probs)
+ probs = probs / np.sum(probs)
+ return transforms, probs
+
+
+def rand_augment_choices(name: str, increasing=True):
+ if name == 'weights':
+ return _RAND_WEIGHTED_0
+ if name == '3aw':
+ return _RAND_WEIGHTED_3A
+ if name == '3a':
+ return _RAND_3A
+ return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
+
+
+def rand_augment_ops(
+ magnitude: Union[int, float] = 10,
+ prob: float = 0.5,
+ hparams: Optional[Dict] = None,
+ transforms: Optional[Union[Dict, List]] = None,
+):
+ hparams = hparams or _HPARAMS_DEFAULT
+ transforms = transforms or _RAND_TRANSFORMS
+ return [
+ AugmentOp(name, prob=prob, magnitude=magnitude, hparams=hparams)
+ for name in transforms
+ ]
+
+
+class RandAugment:
+
+ def __init__(self, ops, num_layers=2, choice_weights=None):
+ self.ops = ops
+ self.num_layers = num_layers
+ self.choice_weights = choice_weights
+
+ def __call__(self, img):
+ # no replacement when using weighted choice
+ ops = np.random.choice(
+ self.ops,
+ self.num_layers,
+ replace=self.choice_weights is None,
+ p=self.choice_weights,
+ )
+ for op in ops:
+ img = op(img)
+ return img
+
+ def __repr__(self):
+ fs = self.__class__.__name__ + f'(n={self.num_layers}, ops='
+ for op in self.ops:
+ fs += f'\n\t{op}'
+ fs += ')'
+ return fs
+
+
+def rand_augment_transform(
+ config_str: str,
+ hparams: Optional[Dict] = None,
+ transforms: Optional[Union[str, Dict, List]] = None,
+):
+ """Create a RandAugment transform.
+
+ Args:
+ config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
+ by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
+            The remaining sections (order not significant) determine
+ 'm' - integer magnitude of rand augment
+ 'n' - integer num layers (number of transform ops selected per image)
+ 'p' - float probability of applying each layer (default 0.5)
+ 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
+ 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
+ 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
+ 't' - str name of transform set to use
+ Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
+ 'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2
+
+ hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme
+
+ Returns:
+ A PyTorch compatible Transform
+ """
+ magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
+ num_layers = 2 # default to 2 ops per image
+ increasing = False
+ prob = 0.5
+ config = config_str.split('-')
+ assert config[0] == 'rand'
+ config = config[1:]
+ for c in config:
+ if c.startswith('t'):
+ # NOTE old 'w' key was removed, 'w0' is not equivalent to 'tweights'
+ val = str(c[1:])
+ if transforms is None:
+ transforms = val
+ else:
+ # numeric options
+ cs = re.split(r'(\d.*)', c)
+ if len(cs) < 2:
+ continue
+ key, val = cs[:2]
+ if key == 'mstd':
+ # noise param / randomization of magnitude values
+ mstd = float(val)
+ if mstd > 100:
+ # use uniform sampling in 0 to magnitude if mstd is > 100
+ mstd = float('inf')
+ hparams.setdefault('magnitude_std', mstd)
+ elif key == 'mmax':
+ # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
+ hparams.setdefault('magnitude_max', int(val))
+ elif key == 'inc':
+ if bool(val):
+ increasing = True
+ elif key == 'm':
+ magnitude = int(val)
+ elif key == 'n':
+ num_layers = int(val)
+ elif key == 'p':
+ prob = float(val)
+ else:
+ assert False, 'Unknown RandAugment config section'
+
+ if isinstance(transforms, str):
+ transforms = rand_augment_choices(transforms, increasing=increasing)
+ elif transforms is None:
+ transforms = _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
+
+ choice_weights = None
+ if isinstance(transforms, Dict):
+ transforms, choice_weights = _get_weighted_transforms(transforms)
+
+ ra_ops = rand_augment_ops(magnitude=magnitude,
+ prob=prob,
+ hparams=hparams,
+ transforms=transforms)
+ return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
+
+
+_AUGMIX_TRANSFORMS = [
+ 'AutoContrast',
+ 'ColorIncreasing', # not in paper
+ 'ContrastIncreasing', # not in paper
+ 'BrightnessIncreasing', # not in paper
+ 'SharpnessIncreasing', # not in paper
+ 'Equalize',
+ 'Rotate',
+ 'PosterizeIncreasing',
+ 'SolarizeIncreasing',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateXRel',
+ 'TranslateYRel',
+]
+
+
+def augmix_ops(
+ magnitude: Union[int, float] = 10,
+ hparams: Optional[Dict] = None,
+ transforms: Optional[Union[str, Dict, List]] = None,
+):
+ hparams = hparams or _HPARAMS_DEFAULT
+ transforms = transforms or _AUGMIX_TRANSFORMS
+ return [
+ AugmentOp(name, prob=1.0, magnitude=magnitude, hparams=hparams)
+ for name in transforms
+ ]
+
+
+class AugMixAugment:
+ """AugMix Transform Adapted and improved from impl here:
+ https://github.com/google-research/augmix/blob/master/imagenet.py From
+ paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and
+ Uncertainty - https://arxiv.org/abs/1912.02781."""
+
+ def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
+ self.ops = ops
+ self.alpha = alpha
+ self.width = width
+ self.depth = depth
+ self.blended = blended # blended mode is faster but not well tested
+
+ def _calc_blended_weights(self, ws, m):
+ ws = ws * m
+ cump = 1.
+ rws = []
+ for w in ws[::-1]:
+ alpha = w / cump
+ cump *= (1 - alpha)
+ rws.append(alpha)
+ return np.array(rws[::-1], dtype=np.float32)
+
+ def _apply_blended(self, img, mixing_weights, m):
+        # This is my first crack at implementing a slightly faster mixed augmentation. Instead
+ # of accumulating the mix for each chain in a Numpy array and then blending with original,
+ # it recomputes the blending coefficients and applies one PIL image blend per chain.
+ # TODO the results appear in the right ballpark but they differ by more than rounding.
+ img_orig = img.copy()
+ ws = self._calc_blended_weights(mixing_weights, m)
+ for w in ws:
+ depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+ ops = np.random.choice(self.ops, depth, replace=True)
+ img_aug = img_orig # no ops are in-place, deep copy not necessary
+ for op in ops:
+ img_aug = op(img_aug)
+ img = Image.blend(img, img_aug, w)
+ return img
+
+ def _apply_basic(self, img, mixing_weights, m):
+ # This is a literal adaptation of the paper/official implementation without normalizations and
+ # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
+ # typical augmentation transforms, could use a GPU / Kornia implementation.
+ img_shape = img.size[0], img.size[1], len(img.getbands())
+ mixed = np.zeros(img_shape, dtype=np.float32)
+ for mw in mixing_weights:
+ depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+ ops = np.random.choice(self.ops, depth, replace=True)
+ img_aug = img # no ops are in-place, deep copy not necessary
+ for op in ops:
+ img_aug = op(img_aug)
+ mixed += mw * np.asarray(img_aug, dtype=np.float32)
+ np.clip(mixed, 0, 255., out=mixed)
+ mixed = Image.fromarray(mixed.astype(np.uint8))
+ return Image.blend(img, mixed, m)
+
+ def __call__(self, img):
+ mixing_weights = np.float32(
+ np.random.dirichlet([self.alpha] * self.width))
+ m = np.float32(np.random.beta(self.alpha, self.alpha))
+ if self.blended:
+ mixed = self._apply_blended(img, mixing_weights, m)
+ else:
+ mixed = self._apply_basic(img, mixing_weights, m)
+ return mixed
+
+ def __repr__(self):
+ fs = self.__class__.__name__ + f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops='
+ for op in self.ops:
+ fs += f'\n\t{op}'
+ fs += ')'
+ return fs
+
+
+def augment_and_mix_transform(config_str: str, hparams: Optional[Dict] = None):
+ """Create AugMix PyTorch transform.
+
+ Args:
+ config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
+            by dashes ('-'). The first section defines the specific variant (currently only 'augmix').
+            The remaining sections (order not significant) determine
+ 'm' - integer magnitude (severity) of augmentation mix (default: 3)
+ 'w' - integer width of augmentation chain (default: 3)
+ 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
+ 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
+ 'mstd' - float std deviation of magnitude noise applied (default: 0)
+ Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
+
+ hparams: Other hparams (kwargs) for the Augmentation transforms
+
+ Returns:
+ A PyTorch compatible Transform
+ """
+ magnitude = 3
+ width = 3
+ depth = -1
+ alpha = 1.
+ blended = False
+ config = config_str.split('-')
+ assert config[0] == 'augmix'
+ config = config[1:]
+ for c in config:
+ cs = re.split(r'(\d.*)', c)
+ if len(cs) < 2:
+ continue
+ key, val = cs[:2]
+ if key == 'mstd':
+ # noise param injected via hparams for now
+ hparams.setdefault('magnitude_std', float(val))
+ elif key == 'm':
+ magnitude = int(val)
+ elif key == 'w':
+ width = int(val)
+ elif key == 'd':
+ depth = int(val)
+ elif key == 'a':
+ alpha = float(val)
+ elif key == 'b':
+ blended = bool(val)
+ else:
+ assert False, 'Unknown AugMix config section'
+ hparams.setdefault(
+ 'magnitude_std',
+ float('inf')) # default to uniform sampling (if not set via mstd arg)
+ ops = augmix_ops(magnitude=magnitude, hparams=hparams)
+ return AugMixAugment(ops,
+ alpha=alpha,
+ width=width,
+ depth=depth,
+ blended=blended)
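A minimal sketch of the config-string interface documented in the rand_augment_transform and augment_and_mix_transform docstrings above (not part of the diff); the hparams values are illustrative.

import numpy as np
from PIL import Image

from openrec.preprocess.auto_augment import (augment_and_mix_transform,
                                             rand_augment_transform)

img = Image.fromarray(np.random.randint(0, 256, size=(32, 128, 3), dtype=np.uint8))

# RandAugment: magnitude 9, 3 ops per image, magnitude noise std 0.5.
# translate_const is needed because the default op list contains absolute translations.
ra = rand_augment_transform('rand-m9-n3-mstd0.5', hparams={'translate_const': 100})
out = ra(img)

# AugMix: severity 5, chain width 4, chain depth 2.
am = augment_and_mix_transform('augmix-m5-w4-d2', hparams={})
out = am(out)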
diff --git a/openrec/preprocess/cam_label_encode.py b/openrec/preprocess/cam_label_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..36317a103ce75c5969b471c164243f3135491ad1
--- /dev/null
+++ b/openrec/preprocess/cam_label_encode.py
@@ -0,0 +1,141 @@
+import numpy as np
+import cv2
+from .ar_label_encode import ARLabelEncode
+
+
+def crop_safe(arr, rect, bbs=[], pad=0):
+ rect = np.array(rect)
+ rect[:2] -= pad
+ rect[2:] += 2 * pad
+ v0 = [max(0, rect[0]), max(0, rect[1])]
+ v1 = [
+ min(arr.shape[0], rect[0] + rect[2]),
+ min(arr.shape[1], rect[1] + rect[3])
+ ]
+ arr = arr[v0[0]:v1[0], v0[1]:v1[1], ...]
+ if len(bbs) > 0:
+ for i in range(len(bbs)):
+ bbs[i, 0] -= v0[0]
+ bbs[i, 1] -= v0[1]
+ return arr, bbs
+ else:
+ return arr
+
+
+try:
+    # pygame==2.5.2; only needed when CAMLabelEncode is configured with a font_path
+    import pygame
+    from pygame import freetype
+except ImportError:
+    pass
+
+
+class CAMLabelEncode(ARLabelEncode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ font_path=None,
+ font_size=30,
+ font_strength=0.1,
+ image_shape=[32, 128],
+ **kwargs):
+ super(CAMLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ use_space_char)
+ self.image_shape = image_shape
+
+ if font_path is not None:
+ freetype.init()
+ # init font
+ self.font = freetype.Font(font_path)
+ self.font.antialiased = True
+ self.font.origin = True
+
+ # choose font style
+ self.font.size = font_size
+ self.font.underline = False
+
+ self.font.strong = True
+ self.font.strength = font_strength
+ self.font.oblique = False
+
+ def render_normal(self, font, text):
+ # get the number of lines
+ lines = text.split('\n')
+ lengths = [len(l) for l in lines]
+
+ # font parameters:
+ line_spacing = font.get_sized_height() + 1
+
+ # initialize the surface to proper size:
+ line_bounds = font.get_rect(lines[np.argmax(lengths)])
+ fsize = (round(2.0 * line_bounds.width),
+ round(1.25 * line_spacing * len(lines)))
+ surf = pygame.Surface(fsize, pygame.locals.SRCALPHA, 32)
+
+ bbs = []
+ space = font.get_rect('O')
+ # space = font.get_rect(' ')
+ x, y = 0, 0
+ for l in lines:
+ x = 2 # carriage-return
+ y += line_spacing # line-feed
+
+ for ch in l: # render each character
+ if ch.isspace(): # just shift
+ x += space.width
+ else:
+ # render the character
+ ch_bounds = font.render_to(surf, (x, y), ch)
+ # ch_bounds.x = x + ch_bounds.x
+ # ch_bounds.y = y - ch_bounds.y
+ x += ch_bounds.width + 5
+ bbs.append(np.array(ch_bounds))
+
+ # get the union of characters for cropping:
+ r0 = pygame.Rect(bbs[0])
+ rect_union = r0.unionall(bbs)
+
+ # get the words:
+ # words = ' '.join(text.split())
+
+ # crop the surface to fit the text:
+ bbs = np.array(bbs)
+ surf_arr, bbs = crop_safe(pygame.surfarray.pixels_alpha(surf),
+ rect_union,
+ bbs,
+ pad=5)
+ surf_arr = surf_arr.swapaxes(0, 1)
+
+ # self.visualize_bb(surf_arr,bbs)
+ return surf_arr, bbs
+
+ def __call__(self, data):
+ data = super().__call__(data=data)
+ if data is None:
+ return None
+ word = []
+ for c in data['label'][1:data['length'] + 1]:
+ word.append(self.character[c])
+ word = ''.join(word)
+ # binary mask
+ binary_mask, bbs = self.render_normal(self.font, word)
+ cate_aware_surf = np.zeros((binary_mask.shape[0], binary_mask.shape[1],
+ len(self.character) - 3)).astype(np.uint8)
+ for id, bb in zip(data['label'][1:data['length'] + 1], bbs):
+ char_id = id - 1
+ cate_aware_surf[:, :,
+ char_id][bb[1]:bb[1] + bb[3], bb[0]:bb[0] +
+ bb[2]] = binary_mask[bb[1]:bb[1] + bb[3],
+ bb[0]:bb[0] + bb[2]]
+ binary_mask = cate_aware_surf
+ binary_mask = cv2.resize(
+ binary_mask, (self.image_shape[0] // 2, self.image_shape[1] // 2))
+ if np.max(binary_mask) > 0:
+ binary_mask = binary_mask / np.max(binary_mask) # [0 ~ 1]
+ binary_mask = binary_mask.astype(np.float32)
+ data['binary_mask'] = binary_mask
+ return data
diff --git a/openrec/preprocess/ce_label_encode.py b/openrec/preprocess/ce_label_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cf1888f6ecb48a5034d38b3acdb69f826de5225
--- /dev/null
+++ b/openrec/preprocess/ce_label_encode.py
@@ -0,0 +1,116 @@
+import re
+
+import numpy as np
+
+from tools.utils.logging import get_logger
+
+
+class BaseRecLabelEncode(object):
+ """Convert between text-label and text-index."""
+
+ def __init__(
+ self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ lower=False,
+ ):
+ self.max_text_len = max_text_length
+ self.beg_str = 'sos'
+ self.end_str = 'eos'
+ self.lower = lower
+ self.reverse = False
+ if character_dict_path is None:
+ logger = get_logger()
+ logger.warning(
+                'character_dict_path is None; the model can only recognize digits and lowercase letters'
+ )
+ self.character_str = '0123456789abcdefghijklmnopqrstuvwxyz'
+ dict_character = list(self.character_str)
+ self.lower = True
+ else:
+ self.character_str = []
+ with open(character_dict_path, 'rb') as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip('\n').strip('\r\n')
+ self.character_str.append(line)
+ if use_space_char:
+ self.character_str.append(' ')
+ dict_character = list(self.character_str)
+ if 'arabic' in character_dict_path:
+ self.reverse = True
+ dict_character = self.add_special_char(dict_character)
+ self.dict = {}
+ for i, char in enumerate(dict_character):
+ self.dict[char] = i
+ self.character = dict_character
+
+ def label_reverse(self, text):
+ text_re = []
+ c_current = ''
+ for c in text:
+ if not bool(re.search('[a-zA-Z0-9 :*./%+-١٢٣٤٥٦٧٨٩٠]', c)):
+ if c_current != '':
+ text_re.append(c_current)
+ text_re.append(c)
+ c_current = ''
+ else:
+ c_current += c
+ if c_current != '':
+ text_re.append(c_current)
+
+ return ''.join(text_re[::-1])
+
+ def add_special_char(self, dict_character):
+ return dict_character
+
+ def encode(self, text):
+ """convert text-label into text-index.
+ input:
+ text: text labels of each image. [batch_size]
+
+ output:
+ text: concatenated text index for CTCLoss.
+ [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
+ length: length of each text. [batch_size]
+ """
+ if len(text) == 0 or len(text) > self.max_text_len:
+ return None
+ if self.lower:
+ text = text.lower()
+ text_list = []
+ for char in text:
+ if char not in self.dict:
+ # logger = get_logger()
+ # logger.warning('{} is not in dict'.format(char))
+ continue
+ text_list.append(self.dict[char])
+ if len(text_list) == 0:
+ return None
+ return text_list
+
+
+class CELabelEncode(BaseRecLabelEncode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(CELabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ use_space_char)
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ data['length'] = np.array(len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def add_special_char(self, dict_character):
+ return dict_character
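An illustrative sketch (not part of the diff): with the built-in dictionary, CELabelEncode simply maps each character to its dictionary index and records the label length.

from openrec.preprocess.ce_label_encode import CELabelEncode

encoder = CELabelEncode(max_text_length=25)
sample = encoder({'label': 'abc1'})
print(sample['label'])   # [10, 11, 12, 1]: indices into '0123456789abcdefghijklmnopqrstuvwxyz'
print(sample['length'])  # 4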
diff --git a/openrec/preprocess/char_label_encode.py b/openrec/preprocess/char_label_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca7648c7384286c5170284a079c81aab9c6ac3ac
--- /dev/null
+++ b/openrec/preprocess/char_label_encode.py
@@ -0,0 +1,36 @@
+import numpy as np
+
+from .ctc_label_encode import BaseRecLabelEncode
+
+
+class CharLabelEncode(BaseRecLabelEncode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(CharLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ use_space_char)
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) > self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
+ text_char = text + [104] * (self.max_text_len + 1 - len(text))
+ text.insert(0, 2)
+ text.append(3)
+ text = text + [0] * (self.max_text_len + 2 - len(text))
+ data['label'] = np.array(text)
+ data['label_char'] = np.array(text_char)
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = ['blank', '', '', ''] + dict_character
+ return dict_character
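A rough sketch of what `CharLabelEncode` emits, again assuming the default built-in alphabet; the shapes follow from the illustrative `max_text_length=25`.

```python
from openrec.preprocess.char_label_encode import CharLabelEncode

encoder = CharLabelEncode(max_text_length=25)
out = encoder({'label': 'ocr'})
if out is not None:
    print(out['length'])            # 3
    print(out['label'].shape)       # (27,) = max_text_length + 2 (start/end ids)
    print(out['label_char'].shape)  # (26,) = max_text_length + 1
```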
diff --git a/openrec/preprocess/cppd_label_encode.py b/openrec/preprocess/cppd_label_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..245ed820c3c8fd7e48e0cb29e7220d3f49b4ae89
--- /dev/null
+++ b/openrec/preprocess/cppd_label_encode.py
@@ -0,0 +1,173 @@
+import random
+
+import numpy as np
+
+from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
+
+
+class CPPDLabelEncode(BaseRecLabelEncode):
+ """Convert between text-label and text-index."""
+
+ def __init__(
+ self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ ch=False,
+ # ch_7000=7000,
+ ignore_index=100,
+ use_sos=False,
+ pos_len=False,
+ **kwargs):
+ self.use_sos = use_sos
+ super(CPPDLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ use_space_char)
+ self.ch = ch
+ self.ignore_index = ignore_index
+ self.pos_len = pos_len
+
+ def __call__(self, data):
+ text = data['label']
+ if self.ch:
+ text, text_node_index, text_node_num = self.encodech(text)
+ if text is None:
+ return None
+ if len(text) > self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
+ # text.insert(0, 0)
+ if self.pos_len:
+ text_pos_node = [i_ for i_ in range(len(text), -1, -1)
+ ] + [100] * (self.max_text_len - len(text))
+ else:
+ text_pos_node = [1] * (len(text) + 1) + [0] * (
+ self.max_text_len - len(text))
+
+            text.append(0)
+
+ text = text + [self.ignore_index
+ ] * (self.max_text_len + 1 - len(text))
+
+ data['label'] = np.array(text)
+ data['label_node'] = np.array(text_node_num + text_pos_node)
+ data['label_index'] = np.array(text_node_index)
+ # data['label_ctc'] = np.array(ctc_text)
+ return data
+ else:
+ text, text_char_node, ch_order = self.encode(text)
+
+ if text is None:
+ return None
+ if len(text) > self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
+ # text.insert(0, 0)
+ if self.pos_len:
+ text_pos_node = [i_ for i_ in range(len(text), -1, -1)
+ ] + [100] * (self.max_text_len - len(text))
+ else:
+ text_pos_node = [1] * (len(text) + 1) + [0] * (
+ self.max_text_len - len(text))
+
+ text.append(0)
+
+ text = text + [self.ignore_index
+ ] * (self.max_text_len + 1 - len(text))
+ data['label'] = np.array(text)
+ data['label_node'] = np.array(text_char_node + text_pos_node)
+ data['label_order'] = np.array(ch_order)
+
+ return data
+
+ def add_special_char(self, dict_character):
+ if self.use_sos:
+ dict_character = ['', ''] + dict_character
+ else:
+ dict_character = [''] + dict_character
+ self.num_character = len(dict_character)
+
+ return dict_character
+
+ def encode(self, text):
+ """convert text-label into text-index.
+ input:
+ text: text labels of each image. [batch_size]
+
+ output:
+ text: concatenated text index for CTCLoss.
+ [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
+ length: length of each text. [batch_size]
+ """
+ if len(text) == 0:
+ return None, None, None
+ if self.lower:
+ text = text.lower()
+ text_node = [0 for _ in range(self.num_character)]
+ text_node[0] = 1
+ text_list = []
+ ch_order = []
+ order = 1
+ for char in text:
+ if char not in self.dict:
+ continue
+ text_list.append(self.dict[char])
+ text_node[self.dict[char]] += 1
+ ch_order.append(
+ [self.dict[char], text_node[self.dict[char]], order])
+ order += 1
+
+ no_ch_order = []
+ for char in self.character:
+ if char not in text:
+ no_ch_order.append([self.dict[char], 1, 0])
+ random.shuffle(no_ch_order)
+ ch_order = ch_order + no_ch_order
+ ch_order = ch_order[:self.max_text_len + 1]
+
+ if len(text_list) == 0 or len(text_list) > self.max_text_len:
+ return None, None, None
+        ch_order.sort()  # list.sort() is in-place and returns None
+        return text_list, text_node, ch_order
+
+ def encodech(self, text):
+ """convert text-label into text-index.
+ input:
+ text: text labels of each image. [batch_size]
+
+ output:
+ text: concatenated text index for CTCLoss.
+ [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
+ length: length of each text. [batch_size]
+ """
+ if len(text) == 0:
+ return None, None, None
+ if self.lower:
+ text = text.lower()
+ text_node_dict = {}
+ text_node_dict.update({0: 1})
+ character_index = [_ for _ in range(self.num_character)]
+ text_list = []
+ for char in text:
+ if char not in self.dict:
+ continue
+ i_c = self.dict[char]
+ text_list.append(i_c)
+
+ if i_c in text_node_dict.keys():
+ text_node_dict[i_c] += 1
+ else:
+ text_node_dict.update({i_c: 1})
+ for ic in list(text_node_dict.keys()):
+ character_index.remove(ic)
+ none_char_index = random.sample(character_index,
+ 37 - len(list(text_node_dict.keys())))
+ for ic in none_char_index:
+ text_node_dict[ic] = 0
+
+ text_node_index = sorted(text_node_dict)
+
+ text_node_num = [text_node_dict[k] for k in text_node_index]
+ if len(text_list) == 0 or len(text_list) > self.max_text_len:
+ return None, None, None
+ return text_list, text_node_index, text_node_num
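A hedged sketch of the English (non-`ch`) CPPD targets, assuming the default alphabet and the in-place `ch_order` sort above; key names come from `__call__`, the label string is illustrative.

```python
from openrec.preprocess.cppd_label_encode import CPPDLabelEncode

encoder = CPPDLabelEncode(max_text_length=25)
out = encoder({'label': 'coffee'})
if out is not None:
    print(out['label'].shape)        # (26,) padded with ignore_index=100
    print(out['label_node'].shape)   # (num_character + max_text_length + 1,)
    print(out['label_order'].shape)  # (26, 3): [char_id, occurrence, position]
```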
diff --git a/openrec/preprocess/ctc_label_encode.py b/openrec/preprocess/ctc_label_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..8441fd12deb7b0d845a2f51fc3f30353488470fc
--- /dev/null
+++ b/openrec/preprocess/ctc_label_encode.py
@@ -0,0 +1,124 @@
+import re
+
+import numpy as np
+
+from tools.utils.logging import get_logger
+
+
+class BaseRecLabelEncode(object):
+ """Convert between text-label and text-index."""
+
+ def __init__(
+ self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ lower=False,
+ ):
+ self.max_text_len = max_text_length
+ self.beg_str = 'sos'
+ self.end_str = 'eos'
+ self.lower = lower
+ self.reverse = False
+ if character_dict_path is None:
+ logger = get_logger()
+ logger.warning(
+                'The character_dict_path is None; the model can only recognize digits and lowercase letters'
+ )
+ self.character_str = '0123456789abcdefghijklmnopqrstuvwxyz'
+ dict_character = list(self.character_str)
+ self.lower = True
+ else:
+ self.character_str = []
+ with open(character_dict_path, 'rb') as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip('\n').strip('\r\n')
+ self.character_str.append(line)
+ if use_space_char:
+ self.character_str.append(' ')
+ dict_character = list(self.character_str)
+ if 'arabic' in character_dict_path:
+ self.reverse = True
+ dict_character = self.add_special_char(dict_character)
+ self.dict = {}
+ for i, char in enumerate(dict_character):
+ self.dict[char] = i
+ self.character = dict_character
+
+ def label_reverse(self, text):
+ text_re = []
+ c_current = ''
+ for c in text:
+            if not bool(re.search('[a-zA-Z0-9 :*./%+١٢٣٤٥٦٧٨٩٠-]', c)):  # '-' placed last so it stays literal
+ if c_current != '':
+ text_re.append(c_current)
+ text_re.append(c)
+ c_current = ''
+ else:
+ c_current += c
+ if c_current != '':
+ text_re.append(c_current)
+
+ return ''.join(text_re[::-1])
+
+ def add_special_char(self, dict_character):
+ return dict_character
+
+ def encode(self, text):
+ """convert text-label into text-index.
+ input:
+ text: text labels of each image. [batch_size]
+
+ output:
+ text: concatenated text index for CTCLoss.
+ [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
+ length: length of each text. [batch_size]
+ """
+ if len(text) == 0:
+ return None
+ if self.lower:
+ text = text.lower()
+ text_list = []
+ for char in text:
+ if char not in self.dict:
+ continue
+ text_list.append(self.dict[char])
+ if len(text_list) == 0 or len(text_list) > self.max_text_len:
+ return None
+ return text_list
+
+
+class CTCLabelEncode(BaseRecLabelEncode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(CTCLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ use_space_char)
+ self.is_reverse = kwargs.get('is_reverse', False)
+
+ def __call__(self, data):
+ text = data['label']
+ if self.reverse and self.is_reverse: # for arabic rec
+ text = self.label_reverse(text)
+ text = self.encode(text)
+ if text is None:
+ return None
+ data['length'] = np.array(len(text))
+ text = text + [0] * (self.max_text_len - len(text))
+ data['label'] = np.array(text)
+
+ label = [0] * len(self.character)
+ for x in text:
+ label[x] += 1
+ data['label_ace'] = np.array(label)
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = ['blank'] + dict_character
+ return dict_character
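A short sketch of the CTC targets under the default alphabet (index 0 is the prepended 'blank'); `label_ace` is the per-class count vector consumed by an ACE-style loss. The label and length are illustrative.

```python
from openrec.preprocess.ctc_label_encode import CTCLabelEncode

encoder = CTCLabelEncode(max_text_length=25)
out = encoder({'label': 'text'})
if out is not None:
    print(out['length'])           # 4
    print(out['label'].shape)      # (25,) zero-padded indices
    print(out['label_ace'].shape)  # (37,) = len(character list incl. 'blank')
```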
diff --git a/openrec/preprocess/ep_label_encode.py b/openrec/preprocess/ep_label_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..78ec218e60642ed5ae71cfdbac44495f610bfff1
--- /dev/null
+++ b/openrec/preprocess/ep_label_encode.py
@@ -0,0 +1,38 @@
+import numpy as np
+
+from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
+
+
+class EPLabelEncode(BaseRecLabelEncode):
+ """Convert between text-label and text-index."""
+ EOS = ''
+ PAD = ''
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+
+ super(EPLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ use_space_char)
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) > self.max_text_len:
+ return None
+
+ data['length'] = np.array(len(text))
+ text = text + [self.dict[self.EOS]]
+ text = text + [self.dict[self.PAD]
+ ] * (self.max_text_len + 1 - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = [self.EOS] + dict_character + [self.PAD]
+ return dict_character
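A brief sketch of `EPLabelEncode`: the encoded label is terminated with the `EOS` token and padded with `PAD` up to `max_text_length + 1`; the sample values are illustrative.

```python
from openrec.preprocess.ep_label_encode import EPLabelEncode

encoder = EPLabelEncode(max_text_length=25)
out = encoder({'label': 'edge'})
if out is not None:
    print(out['length'])       # 4
    print(out['label'].shape)  # (26,) = max_text_length + 1
```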
diff --git a/openrec/preprocess/igtr_label_encode.py b/openrec/preprocess/igtr_label_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..50020d3f5640c245d84d1d64dfcf99488cdc8524
--- /dev/null
+++ b/openrec/preprocess/igtr_label_encode.py
@@ -0,0 +1,360 @@
+import copy
+import random
+
+import numpy as np
+
+from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
+
+
+class IGTRLabelEncode(BaseRecLabelEncode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ k=1,
+ ch=False,
+ prompt_error=False,
+ **kwargs):
+ super(IGTRLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ use_space_char)
+ self.ignore_index = self.dict['']
+ self.k = k
+ self.prompt_error = prompt_error
+ self.ch = ch
+ rare_file = kwargs.get('rare_file', None)
+ siml_file = kwargs.get('siml_file', None)
+ siml_char_dict = {}
+ siml_char_list = [0 for _ in range(self.num_character)]
+ if siml_file is not None:
+ with open(siml_file, 'r') as f:
+ for lin in f.readlines():
+ lin_s = lin.strip().split('\t')
+ char_siml = lin_s[0]
+ if char_siml in self.dict:
+ siml_list = []
+ siml_prob = []
+ for i in range(1, len(lin_s), 2):
+ c = lin_s[i]
+ prob = int(lin_s[i + 1])
+ if c in self.dict and prob >= 1:
+ siml_list.append(self.dict[c])
+ siml_prob.append(prob)
+ siml_prob = np.array(siml_prob,
+ dtype=np.float32) / sum(siml_prob)
+ siml_char_dict[self.dict[char_siml]] = [
+ siml_list, siml_prob.tolist()
+ ]
+ siml_char_list[self.dict[char_siml]] = 1
+ self.siml_char_dict = siml_char_dict
+ self.siml_char_list = siml_char_list
+
+ rare_char_list = [0 for _ in range(self.num_character)]
+ if rare_file is not None:
+ with open(rare_file, 'r') as f:
+ for lin in f.readlines():
+ lin_s = lin.strip().split('\t')
+ # print(lin_s)
+ char_rare = lin_s[0]
+ num_appear = int(lin_s[1])
+ if char_rare in self.dict and num_appear < 1000:
+ rare_char_list[self.dict[char_rare]] = 1
+
+ self.rare_char_list = rare_char_list # [self.dict[char] for char in rare_char_list]
+
+ def __call__(self, data):
+ text = data['label'] # coffee
+
+ encoder_result = self.encode(text)
+ if encoder_result is None:
+ return None
+
+ text, text_char_num, ques_list_s, prompt_list_s = encoder_result
+
+ if len(text) > self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
+
+ text = [self.dict['']] + text + [self.dict['']]
+ text = text + [self.dict['']
+ ] * (self.max_text_len + 2 - len(text))
+ data['label'] = np.array(text) # 6
+
+ ques_len_list = []
+ ques2_len_list = []
+ prompt_len_list = []
+
+ prompt_pos_idx_list = []
+ prompt_char_idx_list = []
+ ques_pos_idx_list = []
+ ques1_answer_list = []
+ ques2_char_idx_list = []
+ ques2_answer_list = []
+ ques4_char_num_list = []
+ train_step = 0
+ for prompt_list, ques_list in zip(prompt_list_s, ques_list_s):
+
+ prompt_len = len(prompt_list) + 1
+ prompt_len_list.append(prompt_len)
+ prompt_list = np.array(
+ [[0, self.dict[''], 0]] + prompt_list +
+ [[self.max_text_len + 2, self.dict[''], 0]] *
+ (self.max_text_len - len(prompt_list)))
+ prompt_pos_idx_list.append(prompt_list[:, 0])
+ prompt_char_idx_list.append(prompt_list[:, 1])
+
+ ques_len = len(ques_list)
+ ques_len_list.append(ques_len)
+
+ ques_list = np.array(
+ ques_list + [[self.max_text_len + 2, self.dict[''], 0]] *
+ (self.max_text_len + 1 - ques_len))
+ ques_pos_idx_list.append(ques_list[:, 0])
+ # what is the first and third char?
+ # Is the first character 't'? and Is the third character 'f'?
+ # How many 'c', 's' and 'f' are there in the text image?
+ ques1_answer_list.append(ques_list[:, 1])
+ ques2_char_idx = copy.deepcopy(ques_list[:ques_len, :2])
+ new_ques2_char_idx = []
+ ques2_answer = []
+ for q_2, ques2_idx in enumerate(ques2_char_idx.tolist()):
+
+ if (train_step == 2 or train_step == 3) and q_2 == ques_len - 1:
+ new_ques2_char_idx.append(ques2_idx)
+ ques2_answer.append(1)
+ continue
+ if ques2_idx[1] != self.dict[''] and random.random() > 0.5:
+ select_idx = random.randint(0, self.num_character - 3)
+ new_ques2_char_idx.append([ques2_idx[0], select_idx])
+ if select_idx == ques2_idx[1]:
+ ques2_answer.append(1)
+ else:
+ ques2_answer.append(0)
+
+ if self.siml_char_list[
+ ques2_idx[1]] == 1 and random.random() > 0.5:
+ select_idx_sim_list = random.sample(
+ self.siml_char_dict[ques2_idx[1]][0],
+ min(3, len(self.siml_char_dict[ques2_idx[1]][0])),
+ )
+ for select_idx in select_idx_sim_list:
+ new_ques2_char_idx.append(
+ [ques2_idx[0], select_idx])
+ if select_idx == ques2_idx[1]:
+ ques2_answer.append(1)
+ else:
+ ques2_answer.append(0)
+ else:
+ new_ques2_char_idx.append(ques2_idx)
+ ques2_answer.append(1)
+ ques2_len_list.append(len(new_ques2_char_idx))
+ ques2_char_idx_new = np.array(
+ new_ques2_char_idx +
+ [[self.max_text_len + 2, self.dict['']]] *
+ (self.max_text_len * 4 + 1 - len(new_ques2_char_idx)))
+ ques2_answer = np.array(
+ ques2_answer + [0] *
+ (self.max_text_len * 4 + 1 - len(ques2_answer)))
+ ques2_char_idx_list.append(ques2_char_idx_new)
+ ques2_answer_list.append(ques2_answer)
+
+ ques4_char_num_list.append(ques_list[:, 2])
+ train_step += 1
+
+ data['ques_len_list'] = np.array(ques_len_list, dtype=np.int64)
+ data['ques2_len_list'] = np.array(ques2_len_list, dtype=np.int64)
+ data['prompt_len_list'] = np.array(prompt_len_list, dtype=np.int64)
+
+ data['prompt_pos_idx_list'] = np.array(prompt_pos_idx_list,
+ dtype=np.int64)
+ data['prompt_char_idx_list'] = np.array(prompt_char_idx_list,
+ dtype=np.int64)
+ data['ques_pos_idx_list'] = np.array(ques_pos_idx_list, dtype=np.int64)
+ data['ques1_answer_list'] = np.array(ques1_answer_list, dtype=np.int64)
+ data['ques2_char_idx_list'] = np.array(ques2_char_idx_list,
+ dtype=np.int64)
+ data['ques2_answer_list'] = np.array(ques2_answer_list,
+ dtype=np.float32)
+
+ data['ques3_answer'] = np.array(
+ text_char_num,
+ dtype=np.int64) # np.array([1, 0, 2]) # answer 1, 0, 2
+ data['ques4_char_num_list'] = np.array(ques4_char_num_list)
+
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = [''] + dict_character + [''] + ['']
+ self.num_character = len(dict_character)
+
+ return dict_character
+
+ def encode(self, text):
+ """convert text-label into text-index.
+ input:
+ text: text labels of each image. [batch_size]
+
+ output:
+ text: concatenated text index for CTCLoss.
+ [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
+ length: length of each text. [batch_size]
+ """
+ if len(text) == 0:
+ return None
+ if self.lower:
+ text = text.lower()
+ char_num = [0 for _ in range(self.num_character - 2)]
+ char_num[0] = 1
+ text_list = []
+ qa_text = []
+ pos_i = 0
+ rare_char_qa = []
+ unrare_char_qa = []
+ for char in text:
+ if char not in self.dict:
+ continue
+
+ char_id = self.dict[char]
+ text_list.append(char_id)
+ qa_text.append([pos_i + 1, char_id, char_num[char_id]])
+ if self.rare_char_list[char_id] == 1:
+ rare_char_qa.append([pos_i + 1, char_id, char_num[char_id]])
+ else:
+ unrare_char_qa.append([pos_i + 1, char_id, char_num[char_id]])
+ char_num[char_id] += 1
+ pos_i += 1
+
+ if self.ch:
+ char_num_ch = []
+ char_num_ch_none = []
+ rare_char_num_ch_none = []
+ for i, num in enumerate(char_num):
+ if self.rare_char_list[i] == 1:
+ rare_char_num_ch_none.append([i, num])
+ if num > 0:
+ char_num_ch.append([i, num])
+ else:
+ char_num_ch_none.append([i, 0])
+ none_char_index = random.sample(
+ char_num_ch_none,
+ min(37 - len(char_num_ch), len(char_num_ch_none)))
+ if len(rare_char_num_ch_none) > 0:
+ none_rare_char_index = random.sample(
+ rare_char_num_ch_none,
+ min(40 - len(char_num_ch) - len(none_char_index),
+ len(rare_char_num_ch_none)),
+ )
+ char_num_ch = char_num_ch + none_char_index + none_rare_char_index
+ else:
+ char_num_ch = char_num_ch + none_char_index
+ char_num_ch.sort(key=lambda x: x[0])
+ char_num = char_num_ch
+
+ len_ = len(text_list)
+ if len_ == 0:
+ return None
+ ques_list = [
+ qa_text + [[pos_i + 1, self.dict[''], 0]],
+ [[pos_i + 1, self.dict[''], 0]],
+ ]
+ prompt_list = [qa_text[len_:], qa_text]
+ if len_ == 1:
+ ques_list.append([[self.max_text_len + 1, self.dict[''], 0]])
+ prompt_list.append(
+ [[self.max_text_len + 2, self.dict[''], 0]] * 4 + qa_text)
+ for _ in range(1, self.k):
+ ques_list.append(
+ [[self.max_text_len + 2, self.dict[''], 0]])
+ prompt_list.append(qa_text[1:])
+ else:
+
+ next_id = random.sample(range(1, len_ + 1), 2)
+ for slice_id in next_id:
+ b_i = slice_id - 5 if slice_id - 5 > 0 else 0
+ if slice_id == len_:
+ ques_list.append(
+ [[self.max_text_len + 1, self.dict[''], 0]])
+ else:
+ ques_list.append(
+ qa_text[slice_id:] +
+ [[self.max_text_len + 1, qa_text[slice_id][1], 0]])
+ prompt_list.append(
+ [[self.max_text_len + 2, self.dict[''], 0]] *
+ (5 - slice_id + b_i) + qa_text[b_i:slice_id])
+
+ shuffle_id1 = random.sample(range(1, len_),
+ 2) if len_ > 2 else [1, 0]
+ for slice_id in shuffle_id1:
+ if slice_id == 0:
+ ques_list.append(
+ [[self.max_text_len + 2, self.dict[''], 0]])
+ prompt_list.append(qa_text[:0])
+ else:
+ ques_list.append(qa_text[slice_id:] +
+ [[pos_i + 1, self.dict[''], 0]])
+ prompt_list.append(qa_text[:slice_id])
+
+ if len_ > 2:
+ shuffle_id2 = random.sample(
+ range(1, len_),
+ self.k - 4 if len_ - 1 > self.k - 4 else len_ - 1)
+ if self.k - 4 != len(shuffle_id2):
+ shuffle_id2 += random.sample(range(1, len_),
+ self.k - 4 - len(shuffle_id2))
+ rare_slice_id = len(rare_char_qa)
+ unrare_slice_id = len(unrare_char_qa)
+ for slice_id in shuffle_id2:
+ random.shuffle(qa_text)
+ if len(rare_char_qa) > 0 and random.random() < 0.5:
+ ques_list.append(rare_char_qa[:rare_slice_id] +
+ unrare_char_qa[unrare_slice_id:] +
+ [[pos_i + 1, self.dict[''], 0]])
+ if len(unrare_char_qa[:unrare_slice_id]) > 0:
+ prompt_list1 = random.sample(
+ unrare_char_qa[:unrare_slice_id],
+ random.randint(
+ 1, len(unrare_char_qa[:unrare_slice_id]))
+ if len(unrare_char_qa[:unrare_slice_id]) > 1
+ else 1,
+ )
+ else:
+ prompt_list1 = []
+ if len(rare_char_qa[rare_slice_id:]) > 0:
+ prompt_list2 = random.sample(
+ rare_char_qa[rare_slice_id:],
+ random.randint(
+ 1,
+ len(rare_char_qa[rare_slice_id:])
+ if len(rare_char_qa[rare_slice_id:]) > 1
+ else 1,
+ ),
+ )
+ else:
+ prompt_list2 = []
+ prompt_list.append(prompt_list1 + prompt_list2)
+ random.shuffle(rare_char_qa)
+ random.shuffle(unrare_char_qa)
+ rare_slice_id = random.randint(
+ 1,
+ len(rare_char_qa)) if len(rare_char_qa) > 1 else 1
+ unrare_slice_id = random.randint(
+ 1, len(unrare_char_qa)
+ ) if len(unrare_char_qa) > 1 else 1
+ else:
+ ques_list.append(qa_text[slice_id:] +
+ [[pos_i + 1, self.dict[''], 0]])
+ prompt_list.append(qa_text[:slice_id])
+ else:
+ ques_list.append(qa_text[1:] +
+ [[pos_i + 1, self.dict[''], 0]])
+ prompt_list.append(qa_text[:1])
+ ques_list.append(qa_text[:1] +
+ [[pos_i + 1, self.dict[''], 0]])
+ prompt_list.append(qa_text[1:])
+ ques_list += [[[self.max_text_len + 2, self.dict[''], 0]]
+ ] * (self.k - 6)
+ prompt_list += [qa_text[:0]] * (self.k - 6)
+
+ return text_list, char_num, ques_list, prompt_list
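A tentative usage sketch for `IGTRLabelEncode`. The value of `k` below is illustrative only (it should match the question budget in the IGTR config and must be large enough for the `k - 4` / `k - 6` sampling above); the outputs are randomized question/prompt tensors.

```python
from openrec.preprocess.igtr_label_encode import IGTRLabelEncode

encoder = IGTRLabelEncode(max_text_length=25, k=8)  # k: illustrative question budget
out = encoder({'label': 'coffee'})
if out is not None:
    for key in ('label', 'ques_len_list', 'prompt_char_idx_list',
                'ques2_char_idx_list', 'ques3_answer'):
        print(key, out[key].shape)
```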
diff --git a/openrec/preprocess/mgp_label_encode.py b/openrec/preprocess/mgp_label_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..440280ebcc301301822e4a01eb2c04ecc83f14e4
--- /dev/null
+++ b/openrec/preprocess/mgp_label_encode.py
@@ -0,0 +1,95 @@
+'''
+This code is adapted from:
+https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/MGP-STR
+'''
+import numpy as np
+
+from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
+
+
+class MGPLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+ SPACE = '[s]'
+ GO = '[GO]'
+ list_token = [GO, SPACE]
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ only_char=False,
+ **kwargs):
+ super(MGPLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ use_space_char)
+ # character (str): set of the possible characters.
+ # [GO] for the start token of the attention decoder. [s] for end-of-sentence token.
+
+ self.batch_max_length = max_text_length + len(self.list_token)
+ self.only_char = only_char
+ if not only_char:
+ # transformers==4.2.1
+ from transformers import BertTokenizer, GPT2Tokenizer
+ self.bpe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ self.wp_tokenizer = BertTokenizer.from_pretrained(
+ 'bert-base-uncased')
+
+ def __call__(self, data):
+ text = data['label']
+ char_text, char_len = self.encode(text)
+ if char_text is None:
+ return None
+ data['length'] = np.array(char_len)
+ data['char_label'] = np.array(char_text)
+ if self.only_char:
+ return data
+ bpe_text = self.bpe_encode(text)
+ if bpe_text is None:
+ return None
+ wp_text = self.wp_encode(text)
+ data['bpe_label'] = np.array(bpe_text)
+ data['wp_label'] = wp_text
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = self.list_token + dict_character
+ return dict_character
+
+ def encode(self, text):
+ """ convert text-label into text-index.
+ """
+ if len(text) == 0:
+ return None, None
+ if self.lower:
+ text = text.lower()
+ length = len(text)
+ text = [self.GO] + list(text) + [self.SPACE]
+ text_list = []
+ for char in text:
+ if char not in self.dict:
+ continue
+ text_list.append(self.dict[char])
+ if len(text_list) == 0 or len(text_list) > self.batch_max_length:
+ return None, None
+ text_list = text_list + [self.dict[self.GO]
+ ] * (self.batch_max_length - len(text_list))
+ return text_list, length
+
+ def bpe_encode(self, text):
+ if len(text) == 0:
+ return None
+ token = self.bpe_tokenizer(text)['input_ids']
+ text_list = [1] + token + [2]
+ if len(text_list) == 0 or len(text_list) > self.batch_max_length:
+ return None
+ text_list = text_list + [self.dict[self.GO]
+ ] * (self.batch_max_length - len(text_list))
+ return text_list
+
+ def wp_encode(self, text):
+ wp_target = self.wp_tokenizer([text],
+ padding='max_length',
+ max_length=self.batch_max_length,
+ truncation=True,
+ return_tensors='np')
+ return wp_target['input_ids'][0]
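A sketch of the character-only path of `MGPLabelEncode`; `only_char=True` avoids the GPT-2/BERT tokenizer downloads that the BPE/WordPiece branches would otherwise trigger. The label is illustrative.

```python
from openrec.preprocess.mgp_label_encode import MGPLabelEncode

encoder = MGPLabelEncode(max_text_length=25, only_char=True)
out = encoder({'label': 'street'})
if out is not None:
    print(out['length'])            # 6
    print(out['char_label'].shape)  # (27,) = max_text_length + len([GO], [s])
```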
diff --git a/openrec/preprocess/parseq_aug.py b/openrec/preprocess/parseq_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..537bd5940842689ba2c25415f9d098fc46a64f65
--- /dev/null
+++ b/openrec/preprocess/parseq_aug.py
@@ -0,0 +1,150 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+import imgaug.augmenters as iaa
+import numpy as np
+from PIL import Image, ImageFilter
+
+from openrec.preprocess import auto_augment
+from openrec.preprocess.auto_augment import _LEVEL_DENOM, LEVEL_TO_ARG, NAME_TO_OP, _randomly_negate, rotate
+
+
+def rotate_expand(img, degrees, **kwargs):
+ """Rotate operation with expand=True to avoid cutting off the
+ characters."""
+ kwargs['expand'] = True
+ return rotate(img, degrees, **kwargs)
+
+
+def _level_to_arg(level, hparams, key, default):
+ magnitude = hparams.get(key, default)
+ level = (level / _LEVEL_DENOM) * magnitude
+ level = _randomly_negate(level)
+ return level,
+
+
+def apply():
+ # Overrides
+ NAME_TO_OP.update({'Rotate': rotate_expand})
+ LEVEL_TO_ARG.update({
+ 'Rotate':
+ partial(_level_to_arg, key='rotate_deg', default=30.),
+ 'ShearX':
+ partial(_level_to_arg, key='shear_x_pct', default=0.3),
+ 'ShearY':
+ partial(_level_to_arg, key='shear_y_pct', default=0.3),
+ 'TranslateXRel':
+ partial(_level_to_arg, key='translate_x_pct', default=0.45),
+ 'TranslateYRel':
+ partial(_level_to_arg, key='translate_y_pct', default=0.45),
+ })
+
+
+apply()
+
+_OP_CACHE = {}
+
+
+def _get_op(key, factory):
+ try:
+ op = _OP_CACHE[key]
+ except KeyError:
+ op = factory()
+ _OP_CACHE[key] = op
+ return op
+
+
+def _get_param(level, img, max_dim_factor, min_level=1):
+ max_level = max(min_level, max_dim_factor * max(img.size))
+ return round(min(level, max_level))
+
+
+def gaussian_blur(img, radius, **__):
+ radius = _get_param(radius, img, 0.02)
+ key = 'gaussian_blur_' + str(radius)
+ op = _get_op(key, lambda: ImageFilter.GaussianBlur(radius))
+ return img.filter(op)
+
+
+def motion_blur(img, k, **__):
+ k = _get_param(k, img, 0.08, 3) | 1 # bin to odd values
+ key = 'motion_blur_' + str(k)
+ op = _get_op(key, lambda: iaa.MotionBlur(k))
+ return Image.fromarray(op(image=np.asarray(img)))
+
+
+def gaussian_noise(img, scale, **_):
+ scale = _get_param(scale, img, 0.25) | 1 # bin to odd values
+ key = 'gaussian_noise_' + str(scale)
+ op = _get_op(key, lambda: iaa.AdditiveGaussianNoise(scale=scale))
+ return Image.fromarray(op(image=np.asarray(img)))
+
+
+def poisson_noise(img, lam, **_):
+ lam = _get_param(lam, img, 0.2) | 1 # bin to odd values
+ key = 'poisson_noise_' + str(lam)
+ op = _get_op(key, lambda: iaa.AdditivePoissonNoise(lam))
+ return Image.fromarray(op(image=np.asarray(img)))
+
+
+def _level_to_arg(level, _hparams, max):
+ level = max * level / auto_augment._LEVEL_DENOM
+ return level,
+
+
+_RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy()
+_RAND_TRANSFORMS.remove(
+ 'SharpnessIncreasing') # remove, interferes with *blur ops
+_RAND_TRANSFORMS.extend([
+ 'GaussianBlur',
+ # 'MotionBlur',
+ # 'GaussianNoise',
+ 'PoissonNoise'
+])
+auto_augment.LEVEL_TO_ARG.update({
+ 'GaussianBlur':
+ partial(_level_to_arg, max=4),
+ 'MotionBlur':
+ partial(_level_to_arg, max=20),
+ 'GaussianNoise':
+ partial(_level_to_arg, max=0.1 * 255),
+ 'PoissonNoise':
+ partial(_level_to_arg, max=40)
+})
+auto_augment.NAME_TO_OP.update({
+ 'GaussianBlur': gaussian_blur,
+ 'MotionBlur': motion_blur,
+ 'GaussianNoise': gaussian_noise,
+ 'PoissonNoise': poisson_noise
+})
+
+
+def rand_augment_transform(magnitude=5, num_layers=3):
+ # These are tuned for magnitude=5, which means that effective magnitudes are half of these values.
+ hparams = {
+ 'rotate_deg': 30,
+ 'shear_x_pct': 0.9,
+ 'shear_y_pct': 0.2,
+ 'translate_x_pct': 0.10,
+ 'translate_y_pct': 0.30
+ }
+ ra_ops = auto_augment.rand_augment_ops(magnitude,
+ hparams=hparams,
+ transforms=_RAND_TRANSFORMS)
+ # Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice)
+ choice_weights = [1. / len(ra_ops) for _ in range(len(ra_ops))]
+ return auto_augment.RandAugment(ra_ops, num_layers, choice_weights)
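A small sketch of applying the PARSeq-style RandAugment policy to a PIL crop, assuming `imgaug` (listed in requirements.txt) is installed; the crop size and magnitude are illustrative.

```python
import numpy as np
from PIL import Image

from openrec.preprocess.parseq_aug import rand_augment_transform

augment = rand_augment_transform(magnitude=5, num_layers=3)
crop = Image.fromarray(np.random.randint(0, 255, (32, 128, 3), dtype=np.uint8))
aug_crop = augment(crop)   # returns an augmented PIL image
print(aug_crop.size)
```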
diff --git a/openrec/preprocess/rec_aug.py b/openrec/preprocess/rec_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdc882a90c6e602ff7e4803a1fdddc64566b5941
--- /dev/null
+++ b/openrec/preprocess/rec_aug.py
@@ -0,0 +1,211 @@
+import random
+
+import cv2
+import numpy as np
+from PIL import Image
+from torchvision.transforms import Compose
+
+from .abinet_aug import CVColorJitter, CVDeterioration, CVGeometry, SVTRDeterioration, SVTRGeometry
+from .parseq_aug import rand_augment_transform
+
+
+class PARSeqAugPIL(object):
+
+ def __init__(self, **kwargs):
+ self.transforms = rand_augment_transform()
+
+ def __call__(self, data):
+ img = data['image']
+ img_aug = self.transforms(img)
+ data['image'] = img_aug
+ return data
+
+
+class PARSeqAug(object):
+
+ def __init__(self, **kwargs):
+ self.transforms = rand_augment_transform()
+
+ def __call__(self, data):
+ img = data['image']
+
+ img = np.array(self.transforms(Image.fromarray(img)))
+ data['image'] = img
+ return data
+
+
+class ABINetAug(object):
+
+ def __init__(self,
+ geometry_p=0.5,
+ deterioration_p=0.25,
+ colorjitter_p=0.25,
+ **kwargs):
+ self.transforms = Compose([
+ CVGeometry(
+ degrees=45,
+ translate=(0.0, 0.0),
+ scale=(0.5, 2.0),
+ shear=(45, 15),
+ distortion=0.5,
+ p=geometry_p,
+ ),
+ CVDeterioration(var=20, degrees=6, factor=4, p=deterioration_p),
+ CVColorJitter(
+ brightness=0.5,
+ contrast=0.5,
+ saturation=0.5,
+ hue=0.1,
+ p=colorjitter_p,
+ ),
+ ])
+
+ def __call__(self, data):
+ img = data['image']
+ img = self.transforms(img)
+ data['image'] = img
+ return data
+
+
+class SVTRAug(object):
+
+ def __init__(self,
+ aug_type=0,
+ geometry_p=0.5,
+ deterioration_p=0.25,
+ colorjitter_p=0.25,
+ **kwargs):
+ self.transforms = Compose([
+ SVTRGeometry(
+ aug_type=aug_type,
+ degrees=45,
+ translate=(0.0, 0.0),
+ scale=(0.5, 2.0),
+ shear=(45, 15),
+ distortion=0.5,
+ p=geometry_p,
+ ),
+ SVTRDeterioration(var=20, degrees=6, factor=4, p=deterioration_p),
+ CVColorJitter(
+ brightness=0.5,
+ contrast=0.5,
+ saturation=0.5,
+ hue=0.1,
+ p=colorjitter_p,
+ ),
+ ])
+
+ def __call__(self, data):
+ img = data['image']
+ img = self.transforms(img)
+ data['image'] = img
+ return data
+
+
+class BaseDataAugmentation(object):
+
+ def __init__(self,
+ crop_prob=0.4,
+ reverse_prob=0.4,
+ noise_prob=0.4,
+ jitter_prob=0.4,
+ blur_prob=0.4,
+ hsv_aug_prob=0.4,
+ **kwargs):
+ self.crop_prob = crop_prob
+ self.reverse_prob = reverse_prob
+ self.noise_prob = noise_prob
+ self.jitter_prob = jitter_prob
+ self.blur_prob = blur_prob
+ self.hsv_aug_prob = hsv_aug_prob
+ # for GaussianBlur
+ self.fil = cv2.getGaussianKernel(ksize=5, sigma=1, ktype=cv2.CV_32F)
+
+ def __call__(self, data):
+ img = data['image']
+ h, w, _ = img.shape
+
+ if random.random() <= self.crop_prob and h >= 20 and w >= 20:
+ img = get_crop(img)
+
+ if random.random() <= self.blur_prob:
+ # GaussianBlur
+ img = cv2.sepFilter2D(img, -1, self.fil, self.fil)
+
+ if random.random() <= self.hsv_aug_prob:
+ img = hsv_aug(img)
+
+ if random.random() <= self.jitter_prob:
+ img = jitter(img)
+
+ if random.random() <= self.noise_prob:
+ img = add_gasuss_noise(img)
+
+ if random.random() <= self.reverse_prob:
+ img = 255 - img
+
+ data['image'] = img
+ return data
+
+
+def hsv_aug(img):
+ """cvtColor."""
+ hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
+ delta = 0.001 * random.random() * flag()
+ hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta)
+ new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
+ return new_img
+
+
+def blur(img):
+ """blur."""
+ h, w, _ = img.shape
+ if h > 10 and w > 10:
+ return cv2.GaussianBlur(img, (5, 5), 1)
+ else:
+ return img
+
+
+def jitter(img):
+ """jitter."""
+ w, h, _ = img.shape
+ if h > 10 and w > 10:
+ thres = min(w, h)
+ s = int(random.random() * thres * 0.01)
+ src_img = img.copy()
+ for i in range(s):
+ img[i:, i:, :] = src_img[:w - i, :h - i, :]
+ return img
+ else:
+ return img
+
+
+def add_gasuss_noise(image, mean=0, var=0.1):
+ """Gasuss noise."""
+
+ noise = np.random.normal(mean, var**0.5, image.shape)
+ out = image + 0.5 * noise
+ out = np.clip(out, 0, 255)
+ out = np.uint8(out)
+ return out
+
+
+def get_crop(image):
+ """random crop."""
+ h, w, _ = image.shape
+ top_min = 1
+ top_max = 8
+ top_crop = int(random.randint(top_min, top_max))
+ top_crop = min(top_crop, h - 1)
+ crop_img = image.copy()
+ ratio = random.randint(0, 1)
+ if ratio:
+ crop_img = crop_img[top_crop:h, :, :]
+ else:
+ crop_img = crop_img[0:h - top_crop, :, :]
+ return crop_img
+
+
+def flag():
+ """flag."""
+ return 1 if random.random() > 0.5000001 else -1
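A quick sketch of the probabilistic base augmentation on an HWC uint8 crop, assuming torchvision and imgaug are importable (they are pulled in by this module's imports); the crop content and probabilities are illustrative.

```python
import numpy as np

from openrec.preprocess.rec_aug import BaseDataAugmentation

aug = BaseDataAugmentation(crop_prob=0.4, blur_prob=0.4)
data = {'image': np.random.randint(0, 255, (48, 160, 3), dtype=np.uint8)}
data = aug(data)
print(data['image'].shape, data['image'].dtype)
```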
diff --git a/openrec/preprocess/resize.py b/openrec/preprocess/resize.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfdb615f9cd4d3b698c49d42647a4a6c5b278190
--- /dev/null
+++ b/openrec/preprocess/resize.py
@@ -0,0 +1,534 @@
+import math
+import random
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms as T
+from torchvision.transforms import functional as F
+
+
+class CDistNetResize(object):
+
+ def __init__(self, image_shape, **kwargs):
+ self.image_shape = image_shape
+
+ def __call__(self, data):
+ img = data['image']
+ _, h, w = self.image_shape
+ # keep_aspect_ratio = False
+ image_pil = Image.fromarray(np.uint8(img))
+ image = image_pil.resize((w, h), Image.LANCZOS)
+ image = np.array(image)
+ # rgb2gray = False
+ image = image.transpose((2, 0, 1))
+ image = image.astype(np.float32) / 128.0 - 1.0
+ data['image'] = image
+ data['valid_ratio'] = 1
+ return data
+
+
+class ABINetResize(object):
+
+ def __init__(self, image_shape, **kwargs):
+ self.image_shape = image_shape
+
+ def __call__(self, data):
+ img = data['image']
+ h, w = img.shape[:2]
+ norm_img, valid_ratio = resize_norm_img_abinet(img, self.image_shape)
+ data['image'] = norm_img
+ data['valid_ratio'] = valid_ratio
+ r = float(w) / float(h)
+ data['real_ratio'] = max(1, round(r))
+ return data
+
+
+def resize_norm_img_abinet(img, image_shape):
+ imgC, imgH, imgW = image_shape
+
+ resized_image = cv2.resize(img, (imgW, imgH),
+ interpolation=cv2.INTER_LINEAR)
+ resized_w = imgW
+ resized_image = resized_image.astype('float32')
+ resized_image = resized_image / 255.0
+
+ mean = np.array([0.485, 0.456, 0.406])
+ std = np.array([0.229, 0.224, 0.225])
+ resized_image = (resized_image - mean[None, None, ...]) / std[None, None,
+ ...]
+ resized_image = resized_image.transpose((2, 0, 1))
+ resized_image = resized_image.astype('float32')
+
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ return resized_image, valid_ratio
+
+
+class SVTRResize(object):
+
+ def __init__(self, image_shape, padding=True, **kwargs):
+ self.image_shape = image_shape
+ self.padding = padding
+
+ def __call__(self, data):
+ img = data['image']
+ h, w = img.shape[:2]
+ norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
+ self.padding)
+ data['image'] = norm_img
+ data['valid_ratio'] = valid_ratio
+ r = float(w) / float(h)
+ data['real_ratio'] = max(1, round(r))
+ return data
+
+
+class RecTVResize(object):
+
+ def __init__(self, image_shape=[32, 128], padding=True, **kwargs):
+ self.padding = padding
+ self.image_shape = image_shape
+ self.interpolation = T.InterpolationMode.BICUBIC
+ transforms = []
+ transforms.extend([
+ T.ToTensor(),
+ T.Normalize(0.5, 0.5),
+ ])
+ self.transforms = T.Compose(transforms)
+
+ def __call__(self, data):
+ img = data['image']
+ imgH, imgW = self.image_shape
+ w, h = img.size
+ if not self.padding:
+ resized_w = imgW
+ else:
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = F.resize(img, (imgH, resized_w),
+ interpolation=self.interpolation)
+ img = self.transforms(resized_image)
+ if resized_w < imgW:
+ img = F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.)
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ data['image'] = img
+ data['valid_ratio'] = valid_ratio
+ r = float(w) / float(h)
+ data['real_ratio'] = max(1, round(r))
+ return data
+
+
+class LongResize(object):
+
+ def __init__(self,
+ base_shape=[[64, 64], [96, 48], [112, 40], [128, 32]],
+ max_ratio=12,
+ base_h=32,
+ padding_rand=False,
+ padding_bi=False,
+ padding=True,
+ **kwargs):
+ self.base_shape = base_shape
+ self.max_ratio = max_ratio
+ self.base_h = base_h
+ self.padding = padding
+ self.padding_rand = padding_rand
+ self.padding_bi = padding_bi
+
+ def __call__(self, data):
+ data = resize_norm_img_long(
+ data,
+ self.base_shape,
+ self.max_ratio,
+ self.base_h,
+ self.padding,
+ self.padding_rand,
+ self.padding_bi,
+ )
+ return data
+
+
+class SliceResize(object):
+
+ def __init__(self, image_shape, padding=True, max_ratio=12, **kwargs):
+ self.image_shape = image_shape
+ self.padding = padding
+ self.max_ratio = max_ratio
+
+ def __call__(self, data):
+ img = data['image']
+ h, w = img.shape[:2]
+ w_bi = w // 2
+ img_list = [
+ img[:, :w_bi, :], img[:, w_bi:2 * w_bi, :],
+ img[:, w_bi // 2:(w_bi // 2) + w_bi, :]
+ ]
+ img_reshape = []
+ for img_s in img_list:
+ norm_img, valid_ratio = resize_norm_img_slice(
+ img_s, self.image_shape, max_ratio=self.max_ratio)
+ img_reshape.append(norm_img[None, :, :, :])
+ data['image'] = np.concatenate(img_reshape, 0)
+ data['valid_ratio'] = valid_ratio
+ return data
+
+
+class SliceTVResize(object):
+
+ def __init__(self,
+ image_shape,
+ padding=True,
+ base_shape=[[64, 64], [96, 48], [112, 40], [128, 32]],
+ max_ratio=12,
+ base_h=32,
+ **kwargs):
+ self.image_shape = image_shape
+ self.padding = padding
+ self.max_ratio = max_ratio
+ self.base_h = base_h
+ self.interpolation = T.InterpolationMode.BICUBIC
+ transforms = []
+ transforms.extend([
+ T.ToTensor(),
+ T.Normalize(0.5, 0.5),
+ ])
+ self.transforms = T.Compose(transforms)
+
+ def __call__(self, data):
+ img = data['image']
+ w, h = img.size
+ w_ratio = ((w // h) // 2) * 2
+ w_ratio = max(6, w_ratio)
+ img = F.resize(img, (self.base_h, self.base_h * w_ratio),
+ interpolation=self.interpolation)
+ img = self.transforms(img)
+ img_list = []
+ for i in range(0, w_ratio // 2 - 1):
+ img_list.append(img[None, :, :,
+ i * 2 * self.base_h:(i * 2 + 4) * self.base_h])
+ data['image'] = torch.concat(img_list, 0)
+ data['valid_ratio'] = float(w_ratio) / w
+ return data
+
+
+class RecTVResizeRatio(object):
+
+ def __init__(self,
+ image_shape=[32, 128],
+ padding=True,
+ base_shape=[[64, 64], [96, 48], [112, 40], [128, 32]],
+ max_ratio=12,
+ base_h=32,
+ **kwargs):
+ self.padding = padding
+ self.image_shape = image_shape
+ self.max_ratio = max_ratio
+ self.base_shape = base_shape
+ self.base_h = base_h
+ self.interpolation = T.InterpolationMode.BICUBIC
+ transforms = []
+ transforms.extend([
+ T.ToTensor(),
+ T.Normalize(0.5, 0.5),
+ ])
+ self.transforms = T.Compose(transforms)
+
+ def __call__(self, data):
+ img = data['image']
+ imgH, imgW = self.image_shape
+ w, h = img.size
+ gen_ratio = round(float(w) / float(h))
+ ratio_resize = 1 if gen_ratio == 0 else gen_ratio
+ ratio_resize = min(ratio_resize, self.max_ratio)
+ imgW, imgH = self.base_shape[ratio_resize -
+ 1] if ratio_resize <= 4 else [
+ self.base_h *
+ ratio_resize, self.base_h
+ ]
+ if not self.padding:
+ resized_w = imgW
+ else:
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = F.resize(img, (imgH, resized_w),
+ interpolation=self.interpolation)
+ img = self.transforms(resized_image)
+ if resized_w < imgW:
+ img = F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.)
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ data['image'] = img
+ data['valid_ratio'] = valid_ratio
+ return data
+
+
+class RecDynamicResize(object):
+
+ def __init__(self, image_shape=[32, 128], padding=True, **kwargs):
+ self.padding = padding
+ self.image_shape = image_shape
+ self.max_ratio = image_shape[1] * 1.0 / image_shape[0]
+
+ def __call__(self, data):
+ img = data['image']
+ imgH, imgW = self.image_shape
+ h, w, imgC = img.shape
+ ratio = w / float(h)
+ max_wh_ratio = max(ratio, self.max_ratio)
+ imgW = int(imgH * max_wh_ratio)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ resized_image = resized_image.astype('float32')
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+ padding_im[:, :, 0:resized_w] = resized_image
+ data['image'] = padding_im
+ return data
+
+
+def resize_norm_img_slice(
+ img,
+ image_shape,
+ base_shape=[[64, 64], [96, 48], [112, 40], [128, 32]],
+ max_ratio=12,
+ base_h=32,
+ padding=True,
+):
+ imgC, imgH, imgW = image_shape
+ h = img.shape[0]
+ w = img.shape[1]
+ gen_ratio = round(float(w) / float(h))
+ ratio_resize = 1 if gen_ratio == 0 else gen_ratio
+ ratio_resize = min(ratio_resize, max_ratio)
+ imgW, imgH = base_shape[ratio_resize - 1] if ratio_resize <= 4 else [
+ base_h * ratio_resize, base_h
+ ]
+ if not padding:
+ resized_image = cv2.resize(img, (imgW, imgH))
+ resized_w = imgW
+ else:
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio * (random.random() + 0.5)))
+ resized_w = min(imgW, resized_w)
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+ padding_im[:, :, :resized_w] = resized_image
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ return padding_im, valid_ratio
+
+
+def resize_norm_img(img,
+ image_shape,
+ padding=True,
+ interpolation=cv2.INTER_LINEAR):
+ imgC, imgH, imgW = image_shape
+ h = img.shape[0]
+ w = img.shape[1]
+ if not padding:
+ resized_image = cv2.resize(img, (imgW, imgH),
+ interpolation=interpolation)
+ resized_w = imgW
+ else:
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ resized_image = resized_image.astype('float32')
+ if image_shape[0] == 1:
+ resized_image = resized_image / 255
+ resized_image = resized_image[np.newaxis, :]
+ else:
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+ padding_im[:, :, 0:resized_w] = resized_image
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ return padding_im, valid_ratio
+
+
+def resize_norm_img_long(
+ data,
+ base_shape=[[64, 64], [96, 48], [112, 40], [128, 32]],
+ max_ratio=12,
+ base_h=32,
+ padding=True,
+ padding_rand=False,
+ padding_bi=False,
+):
+ img = data['image']
+ h = img.shape[0]
+ w = img.shape[1]
+ gen_ratio = data.get('gen_ratio', 0)
+ if gen_ratio == 0:
+ ratio = w / float(h)
+ gen_ratio = round(ratio) if ratio > 0.5 else 1
+    gen_ratio = min(gen_ratio, max_ratio)
+ if padding_rand and random.random() < 0.5:
+ padding = False if padding else True
+ imgW, imgH = base_shape[gen_ratio -
+ 1] if gen_ratio <= len(base_shape) else [
+ base_h * gen_ratio, base_h
+ ]
+ if not padding:
+ resized_image = cv2.resize(img, (imgW, imgH),
+ interpolation=cv2.INTER_LINEAR)
+ resized_w = imgW
+ else:
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio * (random.random() + 0.5)))
+ resized_w = min(imgW, resized_w)
+
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ resized_image = resized_image.astype('float32')
+
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
+ if padding_bi and random.random() < 0.5:
+ padding_im[:, :, -resized_w:] = resized_image
+ else:
+ padding_im[:, :, :resized_w] = resized_image
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ data['image'] = padding_im
+ data['valid_ratio'] = valid_ratio
+ data['gen_ratio'] = imgW // imgH
+ data['real_ratio'] = w // h
+ return data
+
+
+class VisionLANResize(object):
+
+ def __init__(self, image_shape, **kwargs):
+ self.image_shape = image_shape
+
+ def __call__(self, data):
+ img = data['image']
+
+ imgC, imgH, imgW = self.image_shape
+ resized_image = cv2.resize(img, (imgW, imgH))
+ resized_image = resized_image.astype('float32')
+ if imgC == 1:
+ resized_image = resized_image / 255
+ norm_img = resized_image[np.newaxis, :]
+ else:
+ norm_img = resized_image.transpose((2, 0, 1)) / 255
+
+ data['image'] = norm_img
+ data['valid_ratio'] = 1.0
+ return data
+
+
+class RobustScannerRecResizeImg(object):
+
+ def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
+ self.image_shape = image_shape
+ self.width_downsample_ratio = width_downsample_ratio
+
+ def __call__(self, data):
+ img = data['image']
+ norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
+ img, self.image_shape, self.width_downsample_ratio)
+ data['image'] = norm_img
+ data['resized_shape'] = resize_shape
+ data['pad_shape'] = pad_shape
+ data['valid_ratio'] = valid_ratio
+ return data
+
+
+def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
+ imgC, imgH, imgW_min, imgW_max = image_shape
+ h = img.shape[0]
+ w = img.shape[1]
+ valid_ratio = 1.0
+ # make sure new_width is an integral multiple of width_divisor.
+ width_divisor = int(1 / width_downsample_ratio)
+ # resize
+ ratio = w / float(h)
+ resize_w = math.ceil(imgH * ratio)
+ if resize_w % width_divisor != 0:
+ resize_w = round(resize_w / width_divisor) * width_divisor
+ if imgW_min is not None:
+ resize_w = max(imgW_min, resize_w)
+ if imgW_max is not None:
+ valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
+ resize_w = min(imgW_max, resize_w)
+ resized_image = cv2.resize(img, (resize_w, imgH))
+ resized_image = resized_image.astype('float32')
+ # norm
+ if image_shape[0] == 1:
+ resized_image = resized_image / 255
+ resized_image = resized_image[np.newaxis, :]
+ else:
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ resize_shape = resized_image.shape
+ padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
+ padding_im[:, :, 0:resize_w] = resized_image
+ pad_shape = padding_im.shape
+
+ return padding_im, resize_shape, pad_shape, valid_ratio
+
+
+class SRNRecResizeImg(object):
+
+ def __init__(self, image_shape, **kwargs):
+ self.image_shape = image_shape
+
+ def __call__(self, data):
+ img = data['image']
+ norm_img = resize_norm_img_srn(img, self.image_shape)
+ data['image'] = norm_img
+
+ return data
+
+
+def resize_norm_img_srn(img, image_shape):
+ imgC, imgH, imgW = image_shape
+
+ img_black = np.zeros((imgH, imgW))
+ im_hei = img.shape[0]
+ im_wid = img.shape[1]
+
+ if im_wid <= im_hei * 1:
+ img_new = cv2.resize(img, (imgH * 1, imgH))
+ elif im_wid <= im_hei * 2:
+ img_new = cv2.resize(img, (imgH * 2, imgH))
+ elif im_wid <= im_hei * 3:
+ img_new = cv2.resize(img, (imgH * 3, imgH))
+ else:
+ img_new = cv2.resize(img, (imgW, imgH))
+
+ img_np = np.asarray(img_new)
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+ img_black[:, 0:img_np.shape[1]] = img_np
+ img_black = img_black[:, :, np.newaxis]
+
+ row, col, c = img_black.shape
+ c = 1
+
+ return np.reshape(img_black, (c, row, col)).astype(np.float32)
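A short sketch of the padded fixed-shape path used by `SVTRResize`, assuming torch and torchvision are installed (this module imports them at the top); the input crop here is random.

```python
import numpy as np

from openrec.preprocess.resize import SVTRResize

resize = SVTRResize(image_shape=[3, 32, 128], padding=True)
data = {'image': np.random.randint(0, 255, (48, 200, 3), dtype=np.uint8)}
data = resize(data)
print(data['image'].shape)  # (3, 32, 128)
print(data['valid_ratio'])  # fraction of the padded width that is real content
```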
diff --git a/openrec/preprocess/smtr_label_encode.py b/openrec/preprocess/smtr_label_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbd2a7014fe4e56d45eb03ea24b4e5a749b855db
--- /dev/null
+++ b/openrec/preprocess/smtr_label_encode.py
@@ -0,0 +1,125 @@
+import copy
+import random
+
+import numpy as np
+
+from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
+
+
+class SMTRLabelEncode(BaseRecLabelEncode):
+ """Convert between text-label and text-index."""
+
+ BOS = ''
+ EOS = ''
+ IN_F = '' # ignore
+ IN_B = '' # ignore
+ PAD = ''
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ sub_str_len=5,
+ **kwargs):
+
+ super(SMTRLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ use_space_char)
+ self.substr_len = sub_str_len
+ self.rang_subs = [i for i in range(1, self.substr_len + 1)]
+ self.idx_char = [i for i in range(1, self.num_character - 5)]
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) > self.max_text_len:
+ return None
+
+ data['length'] = np.array(len(text))
+ text_in = [self.dict[self.IN_F]] * (self.substr_len) + text + [
+ self.dict[self.IN_B]
+ ] * (self.substr_len)
+
+ sub_string_list_pre = []
+ next_label_pre = []
+ sub_string_list = []
+ next_label = []
+ for i in range(self.substr_len, len(text_in) - self.substr_len):
+
+ sub_string_list.append(text_in[i - self.substr_len:i])
+ next_label.append(text_in[i])
+
+ if self.substr_len - i == 0:
+ sub_string_list_pre.append(text_in[-i:])
+ else:
+ sub_string_list_pre.append(text_in[-i:self.substr_len - i])
+
+ next_label_pre.append(text_in[-(i + 1)])
+
+ sub_string_list.append(
+ [self.dict[self.IN_F]] *
+ (self.substr_len - len(text[-self.substr_len:])) +
+ text[-self.substr_len:])
+ next_label.append(self.dict[self.EOS])
+ sub_string_list_pre.append(
+ text[:self.substr_len] + [self.dict[self.IN_B]] *
+ (self.substr_len - len(text[:self.substr_len])))
+ next_label_pre.append(self.dict[self.EOS])
+
+ for sstr, l in zip(sub_string_list[self.substr_len:],
+ next_label[self.substr_len:]):
+
+ id_shu = np.random.choice(self.rang_subs, 2)
+
+ sstr1 = copy.deepcopy(sstr)
+ sstr1[id_shu[0] - 1] = random.randint(1, self.num_character - 5)
+ if sstr1 not in sub_string_list:
+ sub_string_list.append(sstr1)
+ next_label.append(l)
+
+ sstr[id_shu[1] - 1] = random.randint(1, self.num_character - 5)
+
+ for sstr, l in zip(sub_string_list_pre[self.substr_len:],
+ next_label_pre[self.substr_len:]):
+
+ id_shu = np.random.choice(self.rang_subs, 2)
+
+ sstr1 = copy.deepcopy(sstr)
+ sstr1[id_shu[0] - 1] = random.randint(1, self.num_character - 5)
+ if sstr1 not in sub_string_list_pre:
+ sub_string_list_pre.append(sstr1)
+ next_label_pre.append(l)
+ sstr[id_shu[1] - 1] = random.randint(1, self.num_character - 5)
+
+ data['length_subs'] = np.array(len(sub_string_list))
+ sub_string_list = sub_string_list + [
+ [self.dict[self.PAD]] * self.substr_len
+ ] * ((self.max_text_len * 2) + 2 - len(sub_string_list))
+ next_label = next_label + [self.dict[self.PAD]] * (
+ (self.max_text_len * 2) + 2 - len(next_label))
+ data['label_subs'] = np.array(sub_string_list)
+ data['label_next'] = np.array(next_label)
+
+ data['length_subs_pre'] = np.array(len(sub_string_list_pre))
+ sub_string_list_pre = sub_string_list_pre + [
+ [self.dict[self.PAD]] * self.substr_len
+ ] * ((self.max_text_len * 2) + 2 - len(sub_string_list_pre))
+ next_label_pre = next_label_pre + [self.dict[self.PAD]] * (
+ (self.max_text_len * 2) + 2 - len(next_label_pre))
+ data['label_subs_pre'] = np.array(sub_string_list_pre)
+ data['label_next_pre'] = np.array(next_label_pre)
+
+ text = [self.dict[self.BOS]] + text + [self.dict[self.EOS]]
+ text = text + [self.dict[self.PAD]
+ ] * (self.max_text_len + 2 - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = [self.EOS] + dict_character + [
+ self.BOS, self.IN_F, self.IN_B, self.PAD
+ ]
+ self.num_character = len(dict_character)
+ return dict_character
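A hedged sketch of the SMTR targets: sliding-window substrings plus next-character labels in both reading directions, padded to `2 * max_text_length + 2` rows; the label string is illustrative.

```python
from openrec.preprocess.smtr_label_encode import SMTRLabelEncode

encoder = SMTRLabelEncode(max_text_length=25, sub_str_len=5)
out = encoder({'label': 'window'})
if out is not None:
    print(out['label_subs'].shape)      # (52, 5)
    print(out['label_next'].shape)      # (52,)
    print(out['label_subs_pre'].shape)  # (52, 5)
```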
diff --git a/openrec/preprocess/srn_label_encode.py b/openrec/preprocess/srn_label_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..b045307da7114d6c286774245ea6b05da1c4bf71
--- /dev/null
+++ b/openrec/preprocess/srn_label_encode.py
@@ -0,0 +1,37 @@
+import numpy as np
+
+from .ce_label_encode import BaseRecLabelEncode
+
+
+class SRNLabelEncode(BaseRecLabelEncode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(SRNLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ use_space_char)
+
+ def add_special_char(self, dict_character):
+ dict_character = dict_character + ['', '']
+ self.start_idx = len(dict_character) - 2
+ self.end_idx = len(dict_character) - 1
+ return dict_character
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) > self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
+ text = text + [self.end_idx] * (self.max_text_len - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def get_ignored_tokens(self):
+ return [self.start_idx, self.end_idx]
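A minimal sketch of `SRNLabelEncode`, which pads with the end-token index rather than zeros; the sample label is illustrative.

```python
from openrec.preprocess.srn_label_encode import SRNLabelEncode

encoder = SRNLabelEncode(max_text_length=25)
out = encoder({'label': 'srn'})
if out is not None:
    print(out['length'])       # 3
    print(out['label'].shape)  # (25,), padded with end_idx
```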
diff --git a/openrec/preprocess/visionlan_label_encode.py b/openrec/preprocess/visionlan_label_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..b69920dbacc111040231d07265ca69fffa2bd370
--- /dev/null
+++ b/openrec/preprocess/visionlan_label_encode.py
@@ -0,0 +1,67 @@
+from random import sample
+
+import numpy as np
+
+from .ctc_label_encode import BaseRecLabelEncode
+
+
+class VisionLANLabelEncode(BaseRecLabelEncode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(VisionLANLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ use_space_char)
+ self.dict = {}
+ for i, char in enumerate(self.character):
+ self.dict[char] = i
+
+ def __call__(self, data):
+ text = data['label'] # original string
+ # generate occluded text
+ len_str = len(text)
+ if len_str <= 0:
+ return None
+ change_num = 1
+ order = list(range(len_str))
+ change_id = sample(order, change_num)[0]
+ label_sub = text[change_id]
+ if change_id == (len_str - 1):
+ label_res = text[:change_id]
+ elif change_id == 0:
+ label_res = text[1:]
+ else:
+ label_res = text[:change_id] + text[change_id + 1:]
+
+ data['label_res'] = label_res # remaining string
+ data['label_sub'] = label_sub # occluded character
+ data['label_id'] = change_id # character index
+ # encode label
+ text = self.encode(text)
+ if text is None:
+ return None
+ text = [i + 1 for i in text]
+ data['length'] = np.array(len(text))
+ text = text + [0] * (self.max_text_len + 1 - len(text))
+ data['label'] = np.array(text)
+ label_res = self.encode(label_res)
+ label_sub = self.encode(label_sub)
+ if label_res is None:
+ label_res = []
+ else:
+ label_res = [i + 1 for i in label_res]
+ if label_sub is None:
+ label_sub = []
+ else:
+ label_sub = [i + 1 for i in label_sub]
+ data['length_res'] = np.array(len(label_res))
+ data['length_sub'] = np.array(len(label_sub))
+ label_res = label_res + [0] * (self.max_text_len - len(label_res))
+ label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
+ data['label_res'] = np.array(label_res)
+ data['label_sub'] = np.array(label_sub)
+ return data
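A brief sketch of `VisionLANLabelEncode`, which also returns the randomly occluded character and the remaining string used by the language-aware branch; the label is illustrative.

```python
from openrec.preprocess.visionlan_label_encode import VisionLANLabelEncode

encoder = VisionLANLabelEncode(max_text_length=25)
out = encoder({'label': 'vision'})
if out is not None:
    print(out['label'].shape)      # (26,) shifted indices, zero-padded
    print(out['label_res'].shape)  # (25,) remaining string after occlusion
    print(out['label_id'])         # position of the occluded character
```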
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c2ef64f6b699faf479c4b5fd39e18919e943da90
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+imgaug
+lmdb
+numpy
+opencv-python<=4.6.0.66
+pyclipper
+pyyaml
+rapidfuzz
+tqdm
+gradio==4.20.0
\ No newline at end of file
diff --git a/tools/__pycache__/infer_det.cpython-38.pyc b/tools/__pycache__/infer_det.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d66431cd9589a661fddb8051ef7dc29b504f1a2
Binary files /dev/null and b/tools/__pycache__/infer_det.cpython-38.pyc differ
diff --git a/tools/__pycache__/infer_e2e.cpython-38.pyc b/tools/__pycache__/infer_e2e.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed90183c5afc9b31038d26cf74c3dd6dcb99b00d
Binary files /dev/null and b/tools/__pycache__/infer_e2e.cpython-38.pyc differ
diff --git a/tools/__pycache__/infer_rec.cpython-38.pyc b/tools/__pycache__/infer_rec.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..458c4fbd147da50eb317e5fa7829f4c5b255be33
Binary files /dev/null and b/tools/__pycache__/infer_rec.cpython-38.pyc differ
diff --git a/tools/__pycache__/utility.cpython-38.pyc b/tools/__pycache__/utility.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a93ca46b6815ec0ed22abc98e6916d384ae09090
Binary files /dev/null and b/tools/__pycache__/utility.cpython-38.pyc differ
diff --git a/tools/create_lmdb_dataset.py b/tools/create_lmdb_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a96aa92e560ce760de36cd1ee922efe7f5f20fa5
--- /dev/null
+++ b/tools/create_lmdb_dataset.py
@@ -0,0 +1,118 @@
+import os
+import lmdb
+import cv2
+from tqdm import tqdm
+import numpy as np
+import io
+from PIL import Image
+""" a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """
+
+
+def get_datalist(data_dir, data_path, max_len):
+ """
+ 获取训练和验证的数据list
+ :param data_dir: 数据集根目录
+ :param data_path: 训练的dataset文件列表,每个文件内以如下格式存储 ‘path/to/img\tlabel’
+ :return:
+ """
+ train_data = []
+ if isinstance(data_path, list):
+ for p in data_path:
+ train_data.extend(get_datalist(data_dir, p, max_len))
+ else:
+ with open(data_path, 'r', encoding='utf-8') as f:
+ for line in tqdm(f.readlines(),
+ desc=f'load data from {data_path}'):
+ line = (line.strip('\n').replace('.jpg ', '.jpg\t').replace(
+ '.png ', '.png\t').split('\t'))
+ if len(line) > 1:
+ img_path = os.path.join(data_dir, line[0].strip(' '))
+ label = line[1]
+ if len(label) > max_len:
+ continue
+ if os.path.exists(
+ img_path) and os.path.getsize(img_path) > 0:
+ train_data.append([str(img_path), label])
+ return train_data
+
+
+def checkImageIsValid(imageBin):
+ if imageBin is None:
+ return False
+ imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
+ img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
+ if img is None:
+ return False
+ imgH, imgW = img.shape[0], img.shape[1]
+ if imgH * imgW == 0:
+ return False
+ return True
+
+
+def writeCache(env, cache):
+ with env.begin(write=True) as txn:
+ for k, v in cache.items():
+ txn.put(k, v)
+
+
+def createDataset(data_list, outputPath, checkValid=True):
+ """
+ Create LMDB dataset for training and evaluation.
+ ARGS:
+ inputPath : input folder path where starts imagePath
+ outputPath : LMDB output path
+ gtFile : list of image path and label
+ checkValid : if true, check the validity of every image
+ """
+ os.makedirs(outputPath, exist_ok=True)
+ env = lmdb.open(outputPath, map_size=1099511627776)
+ cache = {}
+ cnt = 1
+ for imagePath, label in tqdm(data_list,
+ desc=f'make dataset, save to {outputPath}'):
+ with open(imagePath, 'rb') as f:
+ imageBin = f.read()
+ buf = io.BytesIO(imageBin)
+ w, h = Image.open(buf).size
+ if checkValid:
+ try:
+ if not checkImageIsValid(imageBin):
+ print('%s is not a valid image' % imagePath)
+ continue
+ except Exception:
+ print('error occurred while checking %s' % imagePath)
+ continue
+
+ imageKey = 'image-%09d'.encode() % cnt
+ labelKey = 'label-%09d'.encode() % cnt
+ whKey = 'wh-%09d'.encode() % cnt
+ cache[imageKey] = imageBin
+ cache[labelKey] = label.encode()
+ cache[whKey] = (str(w) + '_' + str(h)).encode()
+
+ if cnt % 1000 == 0:
+ writeCache(env, cache)
+ cache = {}
+ cnt += 1
+ nSamples = cnt - 1
+ cache['num-samples'.encode()] = str(nSamples).encode()
+ writeCache(env, cache)
+ print('Created dataset with %d samples' % nSamples)
+
+
+if __name__ == '__main__':
+ data_dir = './Union14M-L/'
+ label_file_list = [
+ './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_challenging.jsonl.txt',
+ './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_easy.jsonl.txt',
+ './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_hard.jsonl.txt',
+ './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_medium.jsonl.txt',
+ './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_normal.jsonl.txt'
+ ]
+ save_path_root = './Union14M-L-LMDB-Filtered/'
+
+ for data_list in label_file_list:
+ save_path = save_path_root + data_list.split('/')[-1].split(
+ '.')[0] + '/'
+ os.makedirs(save_path, exist_ok=True)
+ print(save_path)
+ train_data_list = get_datalist(data_dir, data_list, 800)
+
+ createDataset(train_data_list, save_path)
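
The script writes the usual CRNN-style LMDB keys (num-samples, image-%09d, label-%09d, plus wh-%09d here), so a produced dataset can be read back roughly as follows (a sketch; the path corresponds to the filter_train_easy split listed above):

    # Sketch: read one sample back from an LMDB produced by createDataset().
    import lmdb
    import cv2
    import numpy as np

    env = lmdb.open('./Union14M-L-LMDB-Filtered/filter_train_easy', readonly=True, lock=False)
    with env.begin() as txn:
        n = int(txn.get('num-samples'.encode()))
        img_bin = txn.get('image-%09d'.encode() % 1)      # sample indices start at 1
        label = txn.get('label-%09d'.encode() % 1).decode()
        wh = txn.get('wh-%09d'.encode() % 1).decode()     # stored as 'w_h'
    img = cv2.imdecode(np.frombuffer(img_bin, np.uint8), cv2.IMREAD_COLOR)
    print(n, label, wh, img.shape)
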
diff --git a/tools/data/__init__.py b/tools/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad1a20331154133be4575cfd6104d02714e17198
--- /dev/null
+++ b/tools/data/__init__.py
@@ -0,0 +1,94 @@
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+import copy
+
+from torch.utils.data import DataLoader, DistributedSampler
+
+from tools.data.lmdb_dataset import LMDBDataSet
+from tools.data.lmdb_dataset_test import LMDBDataSetTest
+from tools.data.multi_scale_sampler import MultiScaleSampler
+from tools.data.ratio_dataset import RatioDataSet
+from tools.data.ratio_dataset_test import RatioDataSetTest
+from tools.data.ratio_dataset_tvresize_test import RatioDataSetTVResizeTest
+from tools.data.ratio_dataset_tvresize import RatioDataSetTVResize
+from tools.data.ratio_sampler import RatioSampler
+from tools.data.simple_dataset import MultiScaleDataSet, SimpleDataSet
+from tools.data.strlmdb_dataset import STRLMDBDataSet
+
+__all__ = [
+ 'build_dataloader',
+ 'transform',
+ 'create_operators',
+]
+
+
+def build_dataloader(config, mode, logger, seed=None, epoch=3):
+ config = copy.deepcopy(config)
+
+ support_dict = [
+ 'SimpleDataSet', 'LMDBDataSet', 'MultiScaleDataSet', 'STRLMDBDataSet',
+ 'LMDBDataSetTest', 'RatioDataSet', 'RatioDataSetTest',
+ 'RatioDataSetTVResize', 'RatioDataSetTVResizeTest'
+ ]
+ module_name = config[mode]['dataset']['name']
+ assert module_name in support_dict, (
+ 'DataSet only supports {}, but got {}'.format(support_dict, module_name))
+ assert mode in ['Train', 'Eval',
+ 'Test'], 'Mode should be Train, Eval or Test.'
+
+ dataset = eval(module_name)(config, mode, logger, seed, epoch=epoch)
+ loader_config = config[mode]['loader']
+ batch_size = loader_config['batch_size_per_card']
+ drop_last = loader_config['drop_last']
+ shuffle = loader_config['shuffle']
+ num_workers = loader_config['num_workers']
+ pin_memory = loader_config.get('pin_memory', False)
+
+ sampler = None
+ batch_sampler = None
+ if 'sampler' in config[mode]:
+ config_sampler = config[mode]['sampler']
+ sampler_name = config_sampler.pop('name')
+ batch_sampler = eval(sampler_name)(dataset, **config_sampler)
+ elif config['Global']['distributed'] and mode == 'Train':
+ sampler = DistributedSampler(dataset=dataset, shuffle=shuffle)
+
+ if 'collate_fn' in loader_config:
+ from . import collate_fn
+
+ collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
+ else:
+ collate_fn = None
+ if batch_sampler is None:
+ data_loader = DataLoader(
+ dataset=dataset,
+ sampler=sampler,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ collate_fn=collate_fn,
+ batch_size=batch_size,
+ drop_last=drop_last,
+ )
+ else:
+ data_loader = DataLoader(
+ dataset=dataset,
+ batch_sampler=batch_sampler,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ collate_fn=collate_fn,
+ )
+ if len(data_loader) == 0:
+ logger.error(
+ f'No images in the {mode.lower()} dataloader, please ensure that\n'
+ '\t1. The number of images in the train label_file_list is greater than or equal to the batch size.\n'
+ '\t2. The annotation file and paths in the configuration file are set correctly.\n'
+ '\t3. The batch size is not larger than the number of images.')
+ sys.exit()
+ return data_loader
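
For orientation, a minimal sketch of the configuration shape build_dataloader expects; the keys are taken from the code above, while the concrete values (data_dir, transforms) are placeholders:

    # Hypothetical config skeleton for build_dataloader; values are illustrative only.
    config = {
        'Global': {'distributed': False},
        'Train': {
            'dataset': {
                'name': 'LMDBDataSet',
                'data_dir': './train_lmdb/',   # placeholder path
                'transforms': [],              # list of preprocess operator configs
            },
            'loader': {
                'batch_size_per_card': 128,
                'shuffle': True,
                'drop_last': True,
                'num_workers': 4,
            },
        },
    }
    # train_loader = build_dataloader(config, 'Train', logger)
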
diff --git a/tools/data/__pycache__/__init__.cpython-38.pyc b/tools/data/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..380783dd80ecf707ddc8f84bf27a7d25accd879f
Binary files /dev/null and b/tools/data/__pycache__/__init__.cpython-38.pyc differ
diff --git a/tools/data/__pycache__/lmdb_dataset.cpython-38.pyc b/tools/data/__pycache__/lmdb_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8465ca39fede15f203b8f6adaa2b4fee31fc585c
Binary files /dev/null and b/tools/data/__pycache__/lmdb_dataset.cpython-38.pyc differ
diff --git a/tools/data/__pycache__/lmdb_dataset_test.cpython-38.pyc b/tools/data/__pycache__/lmdb_dataset_test.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..186d6e238d4bdf756eb04a15128d3c5b3615c13a
Binary files /dev/null and b/tools/data/__pycache__/lmdb_dataset_test.cpython-38.pyc differ
diff --git a/tools/data/__pycache__/multi_scale_sampler.cpython-38.pyc b/tools/data/__pycache__/multi_scale_sampler.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d6feedda32caa53766e1236d307edde56c75431
Binary files /dev/null and b/tools/data/__pycache__/multi_scale_sampler.cpython-38.pyc differ
diff --git a/tools/data/__pycache__/ratio_dataset.cpython-38.pyc b/tools/data/__pycache__/ratio_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24b1de3f347daf006a2dde93313d9efccbaf8867
Binary files /dev/null and b/tools/data/__pycache__/ratio_dataset.cpython-38.pyc differ
diff --git a/tools/data/__pycache__/ratio_dataset_test.cpython-38.pyc b/tools/data/__pycache__/ratio_dataset_test.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7b49e734a0cc8237004d8d0d2231c611fa5f732
Binary files /dev/null and b/tools/data/__pycache__/ratio_dataset_test.cpython-38.pyc differ
diff --git a/tools/data/__pycache__/ratio_dataset_tvresize.cpython-38.pyc b/tools/data/__pycache__/ratio_dataset_tvresize.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fbdfe71973fe1dada1b4bf05287c676a8e7ea752
Binary files /dev/null and b/tools/data/__pycache__/ratio_dataset_tvresize.cpython-38.pyc differ
diff --git a/tools/data/__pycache__/ratio_dataset_tvresize_test.cpython-38.pyc b/tools/data/__pycache__/ratio_dataset_tvresize_test.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13e8c0daa4e6f6768eff2ced50830bbf399d72f6
Binary files /dev/null and b/tools/data/__pycache__/ratio_dataset_tvresize_test.cpython-38.pyc differ
diff --git a/tools/data/__pycache__/ratio_sampler.cpython-38.pyc b/tools/data/__pycache__/ratio_sampler.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2923762d120e1b01e2b854183cda79ca35163cdd
Binary files /dev/null and b/tools/data/__pycache__/ratio_sampler.cpython-38.pyc differ
diff --git a/tools/data/__pycache__/simple_dataset.cpython-38.pyc b/tools/data/__pycache__/simple_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5204a291f2f9475dafa1868a818510913de65f19
Binary files /dev/null and b/tools/data/__pycache__/simple_dataset.cpython-38.pyc differ
diff --git a/tools/data/__pycache__/strlmdb_dataset.cpython-38.pyc b/tools/data/__pycache__/strlmdb_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e032d2e55490611ac22db23e7c9a4cb46ee94ebb
Binary files /dev/null and b/tools/data/__pycache__/strlmdb_dataset.cpython-38.pyc differ
diff --git a/tools/data/collate_fn.py b/tools/data/collate_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..4453005702d3cab061e9ba0c6589a0a7560b9f7a
--- /dev/null
+++ b/tools/data/collate_fn.py
@@ -0,0 +1,100 @@
+import numbers
+from collections import defaultdict
+
+import numpy as np
+import torch
+
+
+class DictCollator(object):
+ """data batch."""
+
+ def __call__(self, batch):
+ data_dict = defaultdict(list)
+ to_tensor_keys = []
+ for sample in batch:
+ for k, v in sample.items():
+ if isinstance(v, (np.ndarray, torch.Tensor, numbers.Number)):
+ if k not in to_tensor_keys:
+ to_tensor_keys.append(k)
+ data_dict[k].append(v)
+ for k in to_tensor_keys:
+ data_dict[k] = torch.from_numpy(np.array(data_dict[k]))
+ return data_dict
+
+
+class ListCollator(object):
+ """data batch."""
+
+ def __call__(self, batch):
+ data_dict = defaultdict(list)
+ to_tensor_idxs = []
+ for sample in batch:
+ for idx, v in enumerate(sample):
+ if isinstance(v, (np.ndarray, torch.Tensor, numbers.Number)):
+ if idx not in to_tensor_idxs:
+ to_tensor_idxs.append(idx)
+ data_dict[idx].append(v)
+ for idx in to_tensor_idxs:
+ data_dict[idx] = torch.from_numpy(np.array(data_dict[idx]))
+ return list(data_dict.values())
+
+
+class SSLRotateCollate(object):
+ """
+ batch: [
+ [(4*3xH*W), (4,)]
+ [(4*3xH*W), (4,)]
+ ...
+ ]
+ """
+
+ def __call__(self, batch):
+ output = [np.concatenate(d, axis=0) for d in zip(*batch)]
+ return output
+
+
+class DyMaskCollator(object):
+ """
+ batch: [
+ image [batch_size, channel, maxHinbatch, maxWinbatch]
+ image_mask [batch_size, channel, maxHinbatch, maxWinbatch]
+ label [batch_size, maxLabelLen]
+ label_mask [batch_size, maxLabelLen]
+ ...
+ ]
+ """
+
+ def __call__(self, batch):
+ max_width, max_height, max_length = 0, 0, 0
+ bs, channel = len(batch), batch[0][0].shape[0]
+ proper_items = []
+ for item in batch:
+ if item[0].shape[1] * max_width > 1600 * 320 or item[0].shape[
+ 2] * max_height > 1600 * 320:
+ continue
+ max_height = item[0].shape[
+ 1] if item[0].shape[1] > max_height else max_height
+ max_width = item[0].shape[
+ 2] if item[0].shape[2] > max_width else max_width
+ max_length = len(
+ item[1]) if len(item[1]) > max_length else max_length
+ proper_items.append(item)
+
+ images, image_masks = np.zeros(
+ (len(proper_items), channel, max_height, max_width),
+ dtype='float32'), np.zeros(
+ (len(proper_items), 1, max_height, max_width), dtype='float32')
+ labels, label_masks = np.zeros((len(proper_items), max_length),
+ dtype='int64'), np.zeros(
+ (len(proper_items), max_length),
+ dtype='int64')
+
+ for i in range(len(proper_items)):
+ _, h, w = proper_items[i][0].shape
+ images[i][:, :h, :w] = proper_items[i][0]
+ image_masks[i][:, :h, :w] = 1
+ l = len(proper_items[i][1])
+ labels[i][:l] = proper_items[i][1]
+ label_masks[i][:l] = 1
+
+ return images, image_masks, labels, label_masks
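
Sketch of plugging one of these collators into a DataLoader; in this repo the same wiring happens in build_dataloader when the loader section sets collate_fn:

    # Hypothetical standalone use of DictCollator; `dataset` is a placeholder object
    # whose samples are dicts of numpy arrays / numbers.
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=32, collate_fn=DictCollator())
    # each yielded batch is a dict with the per-sample keys, values stacked into tensors
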
diff --git a/tools/data/lmdb_dataset.py b/tools/data/lmdb_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..be17c505221a4e7fd830ac33ff2aca3dd44135cd
--- /dev/null
+++ b/tools/data/lmdb_dataset.py
@@ -0,0 +1,142 @@
+import os
+
+import cv2
+import lmdb
+import numpy as np
+from torch.utils.data import Dataset
+
+from openrec.preprocess import create_operators, transform
+
+
+class LMDBDataSet(Dataset):
+
+ def __init__(self, config, mode, logger, seed=None, epoch=1):
+ super(LMDBDataSet, self).__init__()
+
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+ data_dir = dataset_config['data_dir']
+ self.do_shuffle = loader_config['shuffle']
+
+ self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
+ logger.info(f'Initialize indices of datasets: {data_dir}')
+ self.data_idx_order_list = self.dataset_traversal()
+ if self.do_shuffle:
+ np.random.shuffle(self.data_idx_order_list)
+ self.ops = create_operators(dataset_config['transforms'],
+ global_config)
+ self.ext_op_transform_idx = dataset_config.get('ext_op_transform_idx',
+ 1)
+
+ ratio_list = dataset_config.get('ratio_list', [1.0])
+ self.need_reset = True in [x < 1 for x in ratio_list]
+
+ def load_hierarchical_lmdb_dataset(self, data_dir):
+ lmdb_sets = {}
+ dataset_idx = 0
+ for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
+ if not dirnames:
+ env = lmdb.open(
+ dirpath,
+ max_readers=32,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False,
+ )
+ txn = env.begin(write=False)
+ num_samples = int(txn.get('num-samples'.encode()))
+ lmdb_sets[dataset_idx] = {
+ 'dirpath': dirpath,
+ 'env': env,
+ 'txn': txn,
+ 'num_samples': num_samples,
+ }
+ dataset_idx += 1
+ return lmdb_sets
+
+ def dataset_traversal(self):
+ lmdb_num = len(self.lmdb_sets)
+ total_sample_num = 0
+ for lno in range(lmdb_num):
+ total_sample_num += self.lmdb_sets[lno]['num_samples']
+ data_idx_order_list = np.zeros((total_sample_num, 2))
+ beg_idx = 0
+ for lno in range(lmdb_num):
+ tmp_sample_num = self.lmdb_sets[lno]['num_samples']
+ end_idx = beg_idx + tmp_sample_num
+ data_idx_order_list[beg_idx:end_idx, 0] = lno
+ data_idx_order_list[beg_idx:end_idx,
+ 1] = list(range(tmp_sample_num))
+ data_idx_order_list[beg_idx:end_idx, 1] += 1
+ beg_idx = beg_idx + tmp_sample_num
+ return data_idx_order_list
+
+ def get_img_data(self, value):
+ """get_img_data."""
+ if not value:
+ return None
+ imgdata = np.frombuffer(value, dtype='uint8')
+ if imgdata is None:
+ return None
+ imgori = cv2.imdecode(imgdata, 1)
+ if imgori is None:
+ return None
+ return imgori
+
+ def get_ext_data(self):
+ ext_data_num = 0
+ for op in self.ops:
+ if hasattr(op, 'ext_data_num'):
+ ext_data_num = getattr(op, 'ext_data_num')
+ break
+ load_data_ops = self.ops[:self.ext_op_transform_idx]
+ ext_data = []
+
+ while len(ext_data) < ext_data_num:
+ lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(
+ len(self))]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ sample_info = self.get_lmdb_sample_info(
+ self.lmdb_sets[lmdb_idx]['txn'], file_idx)
+ if sample_info is None:
+ continue
+ img, label = sample_info
+ data = {'image': img, 'label': label}
+ data = transform(data, load_data_ops)
+ if data is None:
+ continue
+ ext_data.append(data)
+ return ext_data
+
+ def get_lmdb_sample_info(self, txn, index):
+ label_key = 'label-%09d'.encode() % index
+ label = txn.get(label_key)
+ if label is None:
+ return None
+ label = label.decode('utf-8')
+ img_key = 'image-%09d'.encode() % index
+ imgbuf = txn.get(img_key)
+ return imgbuf, label
+
+ def __getitem__(self, idx):
+ lmdb_idx, file_idx = self.data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ sample_info = self.get_lmdb_sample_info(
+ self.lmdb_sets[lmdb_idx]['txn'], file_idx)
+ if sample_info is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ img, label = sample_info
+ data = {'image': img, 'label': label}
+ data['ext_data'] = self.get_ext_data()
+ outs = transform(data, self.ops)
+ if outs is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ return outs
+
+ def __len__(self):
+ return self.data_idx_order_list.shape[0]
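
For clarity, dataset_traversal() above builds an (N, 2) table of (lmdb_idx, sample_idx) pairs, with sample indices starting at 1 to match the LMDB keys; for two sub-datasets holding 3 and 2 samples it would look like this:

    # Illustration of the traversal table for two LMDBs with 3 and 2 samples.
    import numpy as np

    expected = np.array([[0., 1.],
                         [0., 2.],
                         [0., 3.],
                         [1., 1.],
                         [1., 2.]])
    # row i: sample i is read from lmdb_sets[int(row[0])]['txn'] at key index int(row[1])
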
diff --git a/tools/data/lmdb_dataset_test.py b/tools/data/lmdb_dataset_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fffda74b05b2db0923c1351315c673ec660710b
--- /dev/null
+++ b/tools/data/lmdb_dataset_test.py
@@ -0,0 +1,166 @@
+import io
+import re
+import unicodedata
+
+import lmdb
+from PIL import Image
+from torch.utils.data import Dataset
+
+from openrec.preprocess import create_operators, transform
+
+
+class CharsetAdapter:
+ """Transforms labels according to the target charset."""
+
+ def __init__(self, target_charset) -> None:
+ super().__init__()
+ self.lowercase_only = target_charset == target_charset.lower()
+ self.uppercase_only = target_charset == target_charset.upper()
+ self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')
+
+ def __call__(self, label):
+ if self.lowercase_only:
+ label = label.lower()
+ elif self.uppercase_only:
+ label = label.upper()
+ # Remove unsupported characters
+ label = self.unsupported.sub('', label)
+ return label
+
+
+class LMDBDataSetTest(Dataset):
+ """Dataset interface to an LMDB database.
+
+ It supports both labelled and unlabelled datasets. For unlabelled datasets,
+ the image index itself is returned as the label. Unicode characters are
+ normalized by default. Case-sensitivity is inferred from the charset.
+ Labels are transformed according to the charset.
+ """
+
+ def __init__(self,
+ config,
+ mode,
+ logger,
+ seed=None,
+ epoch=1,
+ gpu_i=0,
+ max_label_len: int = 25,
+ min_image_dim: int = 0,
+ remove_whitespace: bool = True,
+ normalize_unicode: bool = True,
+ unlabelled: bool = False,
+ transform=None):
+ dataset_config = config[mode]['dataset']
+ global_config = config['Global']
+ max_label_len = global_config['max_text_length']
+ self.root = dataset_config['data_dir']
+ self._env = None
+ self.unlabelled = unlabelled
+ self.transform = transform
+ self.labels = []
+ self.filtered_index_list = []
+ self.min_image_dim = min_image_dim
+ self.filter_label = dataset_config.get('filter_label', True)
+ character_dict_path = global_config.get('character_dict_path', None)
+ use_space_char = global_config.get('use_space_char', False)
+ if character_dict_path is None:
+ char_test = '0123456789abcdefghijklmnopqrstuvwxyz'
+ else:
+ char_test = ''
+ with open(character_dict_path, 'rb') as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip('\n').strip('\r\n')
+ char_test += line
+ if use_space_char:
+ char_test += ' '
+ self.ops = create_operators(dataset_config['transforms'],
+ global_config)
+ self.num_samples = self._preprocess_labels(char_test,
+ remove_whitespace,
+ normalize_unicode,
+ max_label_len,
+ min_image_dim)
+
+ def __del__(self):
+ if self._env is not None:
+ self._env.close()
+ self._env = None
+
+ def _create_env(self):
+ return lmdb.open(self.root,
+ max_readers=1,
+ readonly=True,
+ create=False,
+ readahead=False,
+ meminit=False,
+ lock=False)
+
+ @property
+ def env(self):
+ if self._env is None:
+ self._env = self._create_env()
+ return self._env
+
+ def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode,
+ max_label_len, min_image_dim):
+ charset_adapter = CharsetAdapter(charset)
+ with self._create_env() as env, env.begin() as txn:
+ num_samples = int(txn.get('num-samples'.encode()))
+ if self.unlabelled:
+ return num_samples
+ for index in range(num_samples):
+ index += 1 # lmdb starts with 1
+ label_key = f'label-{index:09d}'.encode()
+ label = txn.get(label_key).decode()
+ # Normally, whitespace is removed from the labels.
+ if remove_whitespace:
+ label = ''.join(label.split())
+ # Normalize unicode composites (if any) and convert to compatible ASCII characters
+ if self.filter_label:
+ # if normalize_unicode:
+ label = unicodedata.normalize('NFKD', label).encode(
+ 'ascii', 'ignore').decode()
+ # Filter by length before removing unsupported characters. The original label might be too long.
+ if len(label) > max_label_len:
+ continue
+
+ if self.filter_label:
+ label = charset_adapter(label)
+ # We filter out samples which don't contain any supported characters
+ if not label:
+ continue
+ # Filter images that are too small.
+ if min_image_dim > 0:
+ img_key = f'image-{index:09d}'.encode()
+ img = txn.get(img_key)
+ data = {'image': img, 'label': label}
+ outs = transform(data, self.ops)
+ if outs is None:
+ continue
+ buf = io.BytesIO(img)
+ w, h = Image.open(buf).size
+ if w < self.min_image_dim or h < self.min_image_dim:
+ continue
+ self.labels.append(label)
+ self.filtered_index_list.append(index)
+ return len(self.labels)
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ if self.unlabelled:
+ label = index
+ else:
+ label = self.labels[index]
+ index = self.filtered_index_list[index]
+
+ img_key = f'image-{index:09d}'.encode()
+ with self.env.begin() as txn:
+ img = txn.get(img_key)
+ data = {'image': img, 'label': label}
+ outs = transform(data, self.ops)
+
+ return outs
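
The CharsetAdapter above lower-cases (or upper-cases) labels when the target charset is single-cased and strips every unsupported character; a quick example:

    # Example behaviour of CharsetAdapter with the default lowercase alphanumeric charset.
    adapter = CharsetAdapter('0123456789abcdefghijklmnopqrstuvwxyz')
    print(adapter('Hello, World!'))   # -> 'helloworld'
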
diff --git a/tools/data/multi_scale_sampler.py b/tools/data/multi_scale_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..89edbd22fa0bc1e171411f84b5ad2e6daabf020a
--- /dev/null
+++ b/tools/data/multi_scale_sampler.py
@@ -0,0 +1,177 @@
+import random
+
+import numpy as np
+import torch.distributed as dist
+from torch.utils.data import Sampler
+
+
+class MultiScaleSampler(Sampler):
+
+ def __init__(
+ self,
+ data_source,
+ scales,
+ first_bs=128,
+ fix_bs=True,
+ divided_factor=[8, 16],
+ is_training=True,
+ ratio_wh=0.8,
+ max_w=480.0,
+ seed=None,
+ ):
+ """
+ Multi-scale sampler.
+ Args:
+ data_source(dataset)
+ scales(list): several scales for image resolution
+ first_bs(int): batch size for the first scale in scales
+ divided_factor (list[w, h]): backbones down-sample images by a fixed factor; ensure that the width and height dimensions are multiples of divided_factor.
+ is_training(boolean): mode
+ """
+ # min. and max. spatial dimensions
+ self.data_source = data_source
+ self.data_idx_order_list = np.array(data_source.data_idx_order_list)
+ self.ds_width = data_source.ds_width
+ self.seed = data_source.seed
+ if self.ds_width:
+ self.wh_ratio = data_source.wh_ratio
+ self.wh_ratio_sort = data_source.wh_ratio_sort
+ self.n_data_samples = len(self.data_source)
+ self.ratio_wh = ratio_wh
+ self.max_w = max_w
+
+ if isinstance(scales[0], list):
+ width_dims = [i[0] for i in scales]
+ height_dims = [i[1] for i in scales]
+ elif isinstance(scales[0], int):
+ width_dims = scales
+ height_dims = scales
+ base_im_w = width_dims[0]
+ base_im_h = height_dims[0]
+ base_batch_size = first_bs
+
+ # Get the GPU and node related information
+ if dist.is_initialized():
+ num_replicas = dist.get_world_size()
+ rank = dist.get_rank()
+ else:
+ num_replicas = 1
+ rank = 0
+ # adjust the total samples to avoid batch dropping
+ num_samples_per_replica = int(self.n_data_samples * 1.0 / num_replicas)
+
+ img_indices = [idx for idx in range(self.n_data_samples)]
+
+ self.shuffle = False
+ if is_training:
+ # compute the spatial dimensions and corresponding batch size
+ # ImageNet models down-sample images by a factor of 32.
+ # Ensure that width and height dimensions are multiples of 32.
+ width_dims = [
+ int((w // divided_factor[0]) * divided_factor[0])
+ for w in width_dims
+ ]
+ height_dims = [
+ int((h // divided_factor[1]) * divided_factor[1])
+ for h in height_dims
+ ]
+
+ img_batch_pairs = list()
+ base_elements = base_im_w * base_im_h * base_batch_size
+ for h, w in zip(height_dims, width_dims):
+ if fix_bs:
+ batch_size = base_batch_size
+ else:
+ batch_size = int(max(1, (base_elements / (h * w))))
+ img_batch_pairs.append((w, h, batch_size))
+ self.img_batch_pairs = img_batch_pairs
+ self.shuffle = True
+ else:
+ self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]
+
+ self.img_indices = img_indices
+ self.n_samples_per_replica = num_samples_per_replica
+ self.epoch = 0
+ self.rank = rank
+ self.num_replicas = num_replicas
+
+ self.batch_list = []
+ self.current = 0
+ last_index = num_samples_per_replica * num_replicas
+ indices_rank_i = self.img_indices[self.rank:last_index:self.
+ num_replicas]
+ while self.current < self.n_samples_per_replica:
+ for curr_w, curr_h, curr_bsz in self.img_batch_pairs:
+ end_index = min(self.current + curr_bsz,
+ self.n_samples_per_replica)
+ batch_ids = indices_rank_i[self.current:end_index]
+ n_batch_samples = len(batch_ids)
+ if n_batch_samples != curr_bsz:
+ batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
+ self.current += curr_bsz
+
+ if len(batch_ids) > 0:
+ batch = [curr_w, curr_h, len(batch_ids)]
+ self.batch_list.append(batch)
+ random.shuffle(self.batch_list)
+ self.length = len(self.batch_list)
+ self.batchs_in_one_epoch = self.iter()
+ self.batchs_in_one_epoch_id = [
+ i for i in range(len(self.batchs_in_one_epoch))
+ ]
+
+ def __iter__(self):
+ if self.seed is None:
+ random.seed(self.epoch)
+ self.epoch += 1
+ else:
+ random.seed(self.seed)
+ random.shuffle(self.batchs_in_one_epoch_id)
+ for batch_tuple_id in self.batchs_in_one_epoch_id:
+ yield self.batchs_in_one_epoch[batch_tuple_id]
+
+ def iter(self):
+ if self.shuffle:
+ if self.seed is not None:
+ random.seed(self.seed)
+ else:
+ random.seed(self.epoch)
+ if not self.ds_width:
+ random.shuffle(self.img_indices)
+ random.shuffle(self.img_batch_pairs)
+ indices_rank_i = self.img_indices[
+ self.rank:len(self.img_indices):self.num_replicas]
+ else:
+ indices_rank_i = self.img_indices[
+ self.rank:len(self.img_indices):self.num_replicas]
+
+ start_index = 0
+ batchs_in_one_epoch = []
+ for batch_tuple in self.batch_list:
+ curr_w, curr_h, curr_bsz = batch_tuple
+ end_index = min(start_index + curr_bsz, self.n_samples_per_replica)
+ batch_ids = indices_rank_i[start_index:end_index]
+ n_batch_samples = len(batch_ids)
+ if n_batch_samples != curr_bsz:
+ batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
+ start_index += curr_bsz
+
+ if len(batch_ids) > 0:
+ if self.ds_width:
+ wh_ratio_current = self.wh_ratio[
+ self.wh_ratio_sort[batch_ids]]
+ ratio_current = wh_ratio_current.mean()
+ ratio_current = ratio_current if ratio_current * curr_h < self.max_w else self.max_w / curr_h
+ else:
+ ratio_current = None
+ batch = [(curr_w, curr_h, b_id, ratio_current)
+ for b_id in batch_ids]
+ # yield batch
+ batchs_in_one_epoch.append(batch)
+ return batchs_in_one_epoch
+
+ def set_epoch(self, epoch: int):
+ self.epoch = epoch
+
+ def __len__(self):
+ return self.length
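
A sketch of constructing MultiScaleSampler directly; the arguments mirror the constructor above, and the dataset is assumed to expose data_idx_order_list, ds_width and seed the way RatioDataSet does:

    # Hypothetical sampler construction; `dataset` is assumed to be a RatioDataSet-style object.
    sampler = MultiScaleSampler(dataset,
                                scales=[[128, 32], [160, 32], [192, 32]],
                                first_bs=128,
                                fix_bs=True,
                                divided_factor=[8, 16],
                                is_training=True)
    # iterating the sampler yields one batch at a time: a list of (w, h, sample_id, ratio) tuples
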
diff --git a/tools/data/ratio_dataset.py b/tools/data/ratio_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..514ac3dfeb499d9a1fa8876d31a80c7257e74174
--- /dev/null
+++ b/tools/data/ratio_dataset.py
@@ -0,0 +1,217 @@
+import io
+import math
+import random
+import os
+import cv2
+import lmdb
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset
+
+from openrec.preprocess import create_operators, transform
+
+
+class RatioDataSet(Dataset):
+
+ def __init__(self, config, mode, logger, seed=None, epoch=1):
+ super(RatioDataSet, self).__init__()
+ self.ds_width = config[mode]['dataset'].get('ds_width', True)
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+ max_ratio = loader_config.get('max_ratio', 10)
+ min_ratio = loader_config.get('min_ratio', 1)
+ syn = dataset_config.get('syn', False)
+ if syn:
+ data_dir_list = []
+ data_dir = '../training_aug_lmdb_noerror/ep' + str(epoch)
+ for dir_syn in os.listdir(data_dir):
+ data_dir_list.append(data_dir + '/' + dir_syn)
+ else:
+ data_dir_list = dataset_config['data_dir_list']
+ self.padding = dataset_config.get('padding', True)
+ self.padding_rand = dataset_config.get('padding_rand', False)
+ self.padding_doub = dataset_config.get('padding_doub', False)
+ self.do_shuffle = loader_config['shuffle']
+ self.seed = epoch
+ data_source_num = len(data_dir_list)
+ ratio_list = dataset_config.get('ratio_list', 1.0)
+ if isinstance(ratio_list, (float, int)):
+ ratio_list = [float(ratio_list)] * int(data_source_num)
+ assert (
+ len(ratio_list) == data_source_num
+ ), 'The length of ratio_list should match that of data_dir_list.'
+ self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
+ data_dir_list, ratio_list)
+ for data_dir in data_dir_list:
+ logger.info('Initialize indices of datasets: %s' % data_dir)
+ self.logger = logger
+ self.data_idx_order_list = self.dataset_traversal()
+ wh_ratio = np.around(np.array(self.get_wh_ratio()))
+ self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
+ for i in range(max_ratio + 1):
+ logger.info((1 * (self.wh_ratio == i)).sum())
+ self.wh_ratio_sort = np.argsort(self.wh_ratio)
+ self.ops = create_operators(dataset_config['transforms'],
+ global_config)
+
+ self.need_reset = True in [x < 1 for x in ratio_list]
+ self.error = 0
+ self.base_shape = dataset_config.get(
+ 'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
+ self.base_h = 32
+
+ def get_wh_ratio(self):
+ wh_ratio = []
+ for idx in range(self.data_idx_order_list.shape[0]):
+ lmdb_idx, file_idx = self.data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ wh_key = 'wh-%09d'.encode() % file_idx
+ wh = self.lmdb_sets[lmdb_idx]['txn'].get(wh_key)
+ if wh is None:
+ img_key = f'image-{file_idx:09d}'.encode()
+ img = self.lmdb_sets[lmdb_idx]['txn'].get(img_key)
+ buf = io.BytesIO(img)
+ w, h = Image.open(buf).size
+ else:
+ wh = wh.decode('utf-8')
+ w, h = wh.split('_')
+ wh_ratio.append(float(w) / float(h))
+ return wh_ratio
+
+ def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
+ lmdb_sets = {}
+ dataset_idx = 0
+ for dirpath, ratio in zip(data_dir_list, ratio_list):
+ env = lmdb.open(dirpath,
+ max_readers=32,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False)
+ txn = env.begin(write=False)
+ num_samples = int(txn.get('num-samples'.encode()))
+ lmdb_sets[dataset_idx] = {
+ 'dirpath': dirpath,
+ 'env': env,
+ 'txn': txn,
+ 'num_samples': num_samples,
+ 'ratio_num_samples': int(ratio * num_samples)
+ }
+ dataset_idx += 1
+ return lmdb_sets
+
+ def dataset_traversal(self):
+ lmdb_num = len(self.lmdb_sets)
+ total_sample_num = 0
+ for lno in range(lmdb_num):
+ total_sample_num += self.lmdb_sets[lno]['ratio_num_samples']
+ data_idx_order_list = np.zeros((total_sample_num, 2))
+ beg_idx = 0
+ for lno in range(lmdb_num):
+ tmp_sample_num = self.lmdb_sets[lno]['ratio_num_samples']
+ end_idx = beg_idx + tmp_sample_num
+ data_idx_order_list[beg_idx:end_idx, 0] = lno
+ data_idx_order_list[beg_idx:end_idx, 1] = list(
+ random.sample(range(1, self.lmdb_sets[lno]['num_samples'] + 1),
+ self.lmdb_sets[lno]['ratio_num_samples']))
+ beg_idx = beg_idx + tmp_sample_num
+ return data_idx_order_list
+
+ def get_img_data(self, value):
+ """get_img_data."""
+ if not value:
+ return None
+ imgdata = np.frombuffer(value, dtype='uint8')
+ if imgdata is None:
+ return None
+ imgori = cv2.imdecode(imgdata, 1)
+ if imgori is None:
+ return None
+ return imgori
+
+ def resize_norm_img(self, data, gen_ratio, padding=True):
+ img = data['image']
+ h = img.shape[0]
+ w = img.shape[1]
+ if self.padding_rand and random.random() < 0.5:
+ padding = not padding
+ imgW, imgH = self.base_shape[gen_ratio - 1] if gen_ratio <= 4 else [
+ self.base_h * gen_ratio, self.base_h
+ ]
+ use_ratio = imgW // imgH
+ if use_ratio >= (w // h) + 2:
+ self.error += 1
+ return None
+ if not padding:
+ resized_image = cv2.resize(img, (imgW, imgH),
+ interpolation=cv2.INTER_LINEAR)
+ resized_w = imgW
+ else:
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(
+ math.ceil(imgH * ratio * (random.random() + 0.5)))
+ resized_w = min(imgW, resized_w)
+
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ resized_image = resized_image.astype('float32')
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
+ if self.padding_doub and random.random() < 0.5:
+ padding_im[:, :, -resized_w:] = resized_image
+ else:
+ padding_im[:, :, :resized_w] = resized_image
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ data['image'] = padding_im
+ data['valid_ratio'] = valid_ratio
+ data['real_ratio'] = round(w / h)
+ return data
+
+ def get_lmdb_sample_info(self, txn, index):
+ label_key = 'label-%09d'.encode() % index
+ label = txn.get(label_key)
+ if label is None:
+ return None
+ label = label.decode('utf-8')
+ img_key = 'image-%09d'.encode() % index
+ imgbuf = txn.get(img_key)
+ return imgbuf, label
+
+ def __getitem__(self, properties):
+ img_width = properties[0]
+ img_height = properties[1]
+ idx = properties[2]
+ ratio = properties[3]
+ lmdb_idx, file_idx = self.data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ sample_info = self.get_lmdb_sample_info(
+ self.lmdb_sets[lmdb_idx]['txn'], file_idx)
+ if sample_info is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+ img, label = sample_info
+ data = {'image': img, 'label': label}
+ outs = transform(data, self.ops[:-1])
+ if outs is not None:
+ outs = self.resize_norm_img(outs, ratio, padding=self.padding)
+ if outs is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+ outs = transform(outs, self.ops[-1:])
+ if outs is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+ return outs
+
+ def __len__(self):
+ return self.data_idx_order_list.shape[0]
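
The resize above buckets samples by rounded aspect ratio: buckets 1-4 use the presets in base_shape, while wider buckets keep base_h and grow the width linearly; with the default settings the target shapes are:

    # Illustration of the target shape picked by resize_norm_img for each ratio bucket.
    base_shape = [[64, 64], [96, 48], [112, 40], [128, 32]]
    base_h = 32
    for gen_ratio in range(1, 8):
        imgW, imgH = base_shape[gen_ratio - 1] if gen_ratio <= 4 else [base_h * gen_ratio, base_h]
        print(gen_ratio, (imgW, imgH))
    # 1 -> (64, 64), 2 -> (96, 48), 3 -> (112, 40), 4 -> (128, 32), 5 -> (160, 32), 6 -> (192, 32), ...
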
diff --git a/tools/data/ratio_dataset_test.py b/tools/data/ratio_dataset_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..18adfd66dfd763d86d56bf447a6056a547d195e4
--- /dev/null
+++ b/tools/data/ratio_dataset_test.py
@@ -0,0 +1,273 @@
+import io
+import math
+import random
+import re
+import unicodedata
+
+import cv2
+import lmdb
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset
+
+from openrec.preprocess import create_operators, transform
+
+
+class CharsetAdapter:
+ """Transforms labels according to the target charset."""
+
+ def __init__(self, target_charset) -> None:
+ super().__init__()
+ self.lowercase_only = target_charset == target_charset.lower()
+ self.uppercase_only = target_charset == target_charset.upper()
+ self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')
+
+ def __call__(self, label):
+ if self.lowercase_only:
+ label = label.lower()
+ elif self.uppercase_only:
+ label = label.upper()
+ # Remove unsupported characters
+ label = self.unsupported.sub('', label)
+ return label
+
+
+class RatioDataSetTest(Dataset):
+
+ def __init__(self, config, mode, logger, seed=None, epoch=1):
+ super(RatioDataSetTest, self).__init__()
+ self.ds_width = config[mode]['dataset'].get('ds_width', True)
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+ max_ratio = loader_config.get('max_ratio', 10)
+ min_ratio = loader_config.get('min_ratio', 1)
+ data_dir_list = dataset_config['data_dir_list']
+ self.do_shuffle = loader_config['shuffle']
+ self.seed = epoch
+ self.max_text_length = global_config['max_text_length']
+ data_source_num = len(data_dir_list)
+ ratio_list = dataset_config.get('ratio_list', 1.0)
+ if isinstance(ratio_list, (float, int)):
+ ratio_list = [float(ratio_list)] * int(data_source_num)
+ assert len(
+ ratio_list
+ ) == data_source_num, 'The length of ratio_list should match that of data_dir_list.'
+ self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
+ data_dir_list, ratio_list)
+ for data_dir in data_dir_list:
+ logger.info('Initialize indices of datasets: %s' % data_dir)
+ self.logger = logger
+ data_idx_order_list = self.dataset_traversal()
+ character_dict_path = global_config.get('character_dict_path', None)
+ use_space_char = global_config.get('use_space_char', False)
+ if character_dict_path is None:
+ char_test = '0123456789abcdefghijklmnopqrstuvwxyz'
+ else:
+ char_test = ''
+ with open(character_dict_path, 'rb') as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip('\n').strip('\r\n')
+ char_test += line
+ if use_space_char:
+ char_test += ' '
+ wh_ratio, data_idx_order_list = self.get_wh_ratio(
+ data_idx_order_list, char_test)
+ self.data_idx_order_list = np.array(data_idx_order_list)
+ wh_ratio = np.around(np.array(wh_ratio))
+ self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
+ for i in range(max_ratio + 1):
+ logger.info((1 * (self.wh_ratio == i)).sum())
+ self.wh_ratio_sort = np.argsort(self.wh_ratio)
+ self.ops = create_operators(dataset_config['transforms'],
+ global_config)
+
+ self.need_reset = True in [x < 1 for x in ratio_list]
+ self.error = 0
+ self.base_shape = dataset_config.get(
+ 'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
+ self.base_h = 32
+
+ def get_wh_ratio(self, data_idx_order_list, char_test):
+ wh_ratio = []
+ wh_ratio_len = [[0 for _ in range(26)] for _ in range(11)]
+ data_idx_order_list_filter = []
+ charset_adapter = CharsetAdapter(char_test)
+
+ for idx in range(data_idx_order_list.shape[0]):
+ lmdb_idx, file_idx = data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ wh_key = 'wh-%09d'.encode() % file_idx
+ wh = self.lmdb_sets[lmdb_idx]['txn'].get(wh_key)
+ if wh is None:
+ img_key = f'image-{file_idx:09d}'.encode()
+ img = self.lmdb_sets[lmdb_idx]['txn'].get(img_key)
+ buf = io.BytesIO(img)
+ w, h = Image.open(buf).size
+ else:
+ wh = wh.decode('utf-8')
+ w, h = wh.split('_')
+
+ label_key = 'label-%09d'.encode() % file_idx
+ label = self.lmdb_sets[lmdb_idx]['txn'].get(label_key)
+ if label is not None:
+ # return None
+ label = label.decode('utf-8')
+ # if remove_whitespace:
+ label = ''.join(label.split())
+ # Normalize unicode composites (if any) and convert to compatible ASCII characters
+ # if normalize_unicode:
+ label = unicodedata.normalize('NFKD',
+ label).encode('ascii',
+ 'ignore').decode()
+ # Filter by length before removing unsupported characters. The original label might be too long.
+ if len(label) > self.max_text_length:
+ continue
+ label = charset_adapter(label)
+ if not label:
+ continue
+
+ wh_ratio.append(float(w) / float(h))
+ wh_ratio_len[int(float(w) /
+ float(h)) if int(float(w) /
+ float(h)) <= 10 else
+ 10][len(label) if len(label) <= 25 else 25] += 1
+ data_idx_order_list_filter.append([lmdb_idx, file_idx])
+ self.logger.info(wh_ratio_len)
+ return wh_ratio, data_idx_order_list_filter
+
+ def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
+ lmdb_sets = {}
+ dataset_idx = 0
+ for dirpath, ratio in zip(data_dir_list, ratio_list):
+ env = lmdb.open(dirpath,
+ max_readers=32,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False)
+ txn = env.begin(write=False)
+ num_samples = int(txn.get('num-samples'.encode()))
+ lmdb_sets[dataset_idx] = {
+ 'dirpath': dirpath,
+ 'env': env,
+ 'txn': txn,
+ 'num_samples': num_samples,
+ 'ratio_num_samples': int(ratio * num_samples),
+ }
+ dataset_idx += 1
+ return lmdb_sets
+
+ def dataset_traversal(self):
+ lmdb_num = len(self.lmdb_sets)
+ total_sample_num = 0
+ for lno in range(lmdb_num):
+ total_sample_num += self.lmdb_sets[lno]['ratio_num_samples']
+ data_idx_order_list = np.zeros((total_sample_num, 2))
+ beg_idx = 0
+ for lno in range(lmdb_num):
+ tmp_sample_num = self.lmdb_sets[lno]['ratio_num_samples']
+ end_idx = beg_idx + tmp_sample_num
+ data_idx_order_list[beg_idx:end_idx, 0] = lno
+ data_idx_order_list[beg_idx:end_idx, 1] = list(
+ random.sample(range(1, self.lmdb_sets[lno]['num_samples'] + 1),
+ self.lmdb_sets[lno]['ratio_num_samples']))
+ beg_idx = beg_idx + tmp_sample_num
+ return data_idx_order_list
+
+ def get_img_data(self, value):
+ """get_img_data."""
+ if not value:
+ return None
+ imgdata = np.frombuffer(value, dtype='uint8')
+ if imgdata is None:
+ return None
+ imgori = cv2.imdecode(imgdata, 1)
+ if imgori is None:
+ return None
+ return imgori
+
+ def resize_norm_img(self, data, gen_ratio, padding=True):
+ img = data['image']
+ h = img.shape[0]
+ w = img.shape[1]
+
+ imgW, imgH = self.base_shape[gen_ratio - 1] if gen_ratio <= 4 else [
+ self.base_h * gen_ratio, self.base_h
+ ]
+ use_ratio = imgW // imgH
+ if use_ratio >= (w // h) + 2:
+ self.error += 1
+ return None
+ if not padding:
+ resized_image = cv2.resize(img, (imgW, imgH),
+ interpolation=cv2.INTER_LINEAR)
+ resized_w = imgW
+ else:
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(
+ math.ceil(imgH * ratio * (random.random() + 0.5)))
+ resized_w = min(imgW, resized_w)
+
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ resized_image = resized_image.astype('float32')
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
+ padding_im[:, :, :resized_w] = resized_image
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ data['image'] = padding_im
+ data['valid_ratio'] = valid_ratio
+ data['gen_ratio'] = imgW // imgH
+ data['real_ratio'] = max(1, round(w / h))
+ return data
+
+ def get_lmdb_sample_info(self, txn, index):
+ label_key = 'label-%09d'.encode() % index
+ label = txn.get(label_key)
+ if label is None:
+ return None
+ label = label.decode('utf-8')
+ img_key = 'image-%09d'.encode() % index
+ imgbuf = txn.get(img_key)
+ return imgbuf, label
+
+ def __getitem__(self, properties):
+ img_width = properties[0]
+ img_height = properties[1]
+ idx = properties[2]
+ ratio = properties[3]
+ lmdb_idx, file_idx = self.data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ sample_info = self.get_lmdb_sample_info(
+ self.lmdb_sets[lmdb_idx]['txn'], file_idx)
+ if sample_info is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+ img, label = sample_info
+ data = {'image': img, 'label': label}
+ outs = transform(data, self.ops[:-1])
+ if outs is not None:
+ outs = self.resize_norm_img(outs, ratio, padding=False)
+ if outs is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+
+ outs = transform(outs, self.ops[-1:])
+ if outs is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+ return outs
+
+ def __len__(self):
+ return self.data_idx_order_list.shape[0]
diff --git a/tools/data/ratio_dataset_tvresize.py b/tools/data/ratio_dataset_tvresize.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0f0469d2ba8069908a6fa06175c765ef32d1d03
--- /dev/null
+++ b/tools/data/ratio_dataset_tvresize.py
@@ -0,0 +1,213 @@
+import io
+import math
+import random
+
+import cv2
+import lmdb
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms as T
+from torchvision.transforms import functional as F
+
+from openrec.preprocess import create_operators, transform
+
+
+class RatioDataSetTVResize(Dataset):
+
+ def __init__(self, config, mode, logger, seed=None, epoch=1):
+ super(RatioDataSetTVResize, self).__init__()
+ self.ds_width = config[mode]['dataset'].get('ds_width', True)
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+ max_ratio = loader_config.get('max_ratio', 10)
+ min_ratio = loader_config.get('min_ratio', 1)
+ data_dir_list = dataset_config['data_dir_list']
+ self.padding = dataset_config.get('padding', True)
+ self.padding_rand = dataset_config.get('padding_rand', False)
+ self.padding_doub = dataset_config.get('padding_doub', False)
+ self.do_shuffle = loader_config['shuffle']
+ self.seed = epoch
+ data_source_num = len(data_dir_list)
+ ratio_list = dataset_config.get('ratio_list', 1.0)
+ if isinstance(ratio_list, (float, int)):
+ ratio_list = [float(ratio_list)] * int(data_source_num)
+ assert (
+ len(ratio_list) == data_source_num
+ ), 'The length of ratio_list should match that of data_dir_list.'
+ self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
+ data_dir_list, ratio_list)
+ for data_dir in data_dir_list:
+ logger.info('Initialize indices of datasets: %s' % data_dir)
+ self.logger = logger
+ self.data_idx_order_list = self.dataset_traversal()
+ wh_ratio = np.around(np.array(self.get_wh_ratio()))
+ self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
+ for i in range(max_ratio + 1):
+ logger.info((1 * (self.wh_ratio == i)).sum())
+ self.wh_ratio_sort = np.argsort(self.wh_ratio)
+ self.ops = create_operators(dataset_config['transforms'],
+ global_config)
+
+ self.need_reset = True in [x < 1 for x in ratio_list]
+ self.error = 0
+ self.base_shape = dataset_config.get(
+ 'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
+ self.base_h = dataset_config.get('base_h', 32)
+ self.interpolation = T.InterpolationMode.BICUBIC
+ transforms = []
+ transforms.extend([
+ T.ToTensor(),
+ T.Normalize(0.5, 0.5),
+ ])
+ self.transforms = T.Compose(transforms)
+
+ def get_wh_ratio(self):
+ wh_ratio = []
+ for idx in range(self.data_idx_order_list.shape[0]):
+ lmdb_idx, file_idx = self.data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ wh_key = 'wh-%09d'.encode() % file_idx
+ wh = self.lmdb_sets[lmdb_idx]['txn'].get(wh_key)
+ if wh is None:
+ img_key = f'image-{file_idx:09d}'.encode()
+ img = self.lmdb_sets[lmdb_idx]['txn'].get(img_key)
+ buf = io.BytesIO(img)
+ w, h = Image.open(buf).size
+ else:
+ wh = wh.decode('utf-8')
+ w, h = wh.split('_')
+ wh_ratio.append(float(w) / float(h))
+ return wh_ratio
+
+ def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
+ lmdb_sets = {}
+ dataset_idx = 0
+ for dirpath, ratio in zip(data_dir_list, ratio_list):
+ env = lmdb.open(dirpath,
+ max_readers=32,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False)
+ txn = env.begin(write=False)
+ num_samples = int(txn.get('num-samples'.encode()))
+ lmdb_sets[dataset_idx] = {
+ 'dirpath': dirpath,
+ 'env': env,
+ 'txn': txn,
+ 'num_samples': num_samples,
+ 'ratio_num_samples': int(ratio * num_samples)
+ }
+ dataset_idx += 1
+ return lmdb_sets
+
+ def dataset_traversal(self):
+ lmdb_num = len(self.lmdb_sets)
+ total_sample_num = 0
+ for lno in range(lmdb_num):
+ total_sample_num += self.lmdb_sets[lno]['ratio_num_samples']
+ data_idx_order_list = np.zeros((total_sample_num, 2))
+ beg_idx = 0
+ for lno in range(lmdb_num):
+ tmp_sample_num = self.lmdb_sets[lno]['ratio_num_samples']
+ end_idx = beg_idx + tmp_sample_num
+ data_idx_order_list[beg_idx:end_idx, 0] = lno
+ data_idx_order_list[beg_idx:end_idx, 1] = list(
+ random.sample(range(1, self.lmdb_sets[lno]['num_samples'] + 1),
+ self.lmdb_sets[lno]['ratio_num_samples']))
+ beg_idx = beg_idx + tmp_sample_num
+ return data_idx_order_list
+
+ def get_img_data(self, value):
+ """get_img_data."""
+ if not value:
+ return None
+ imgdata = np.frombuffer(value, dtype='uint8')
+ if imgdata is None:
+ return None
+ imgori = cv2.imdecode(imgdata, 1)
+ if imgori is None:
+ return None
+ return imgori
+
+ def resize_norm_img(self, data, gen_ratio, padding=True):
+ img = data['image']
+ w, h = img.size
+ if self.padding_rand and random.random() < 0.5:
+ padding = not padding
+ imgW, imgH = self.base_shape[gen_ratio - 1] if gen_ratio <= 4 else [
+ self.base_h * gen_ratio, self.base_h
+ ]
+ use_ratio = imgW // imgH
+ if use_ratio >= (w // h) + 2:
+ self.error += 1
+ return None
+ if not padding:
+ resized_w = imgW
+ else:
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(
+ math.ceil(imgH * ratio * (random.random() + 0.5)))
+ resized_w = min(imgW, resized_w)
+ resized_image = F.resize(img, (imgH, resized_w),
+ interpolation=self.interpolation)
+ img = self.transforms(resized_image)
+ if resized_w < imgW and padding:
+ # img = F.pad(img, [0, 0, imgW-resized_w, 0], fill=0.)
+ if self.padding_doub and random.random() < 0.5:
+ img = F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.)
+ else:
+ img = F.pad(img, [imgW - resized_w, 0, 0, 0], fill=0.)
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ data['image'] = img
+ data['valid_ratio'] = valid_ratio
+ return data
+
+ def get_lmdb_sample_info(self, txn, index):
+ label_key = 'label-%09d'.encode() % index
+ label = txn.get(label_key)
+ if label is None:
+ return None
+ label = label.decode('utf-8')
+ img_key = 'image-%09d'.encode() % index
+ imgbuf = txn.get(img_key)
+ return imgbuf, label
+
+ def __getitem__(self, properties):
+ img_width = properties[0]
+ img_height = properties[1]
+ idx = properties[2]
+ ratio = properties[3]
+ lmdb_idx, file_idx = self.data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ sample_info = self.get_lmdb_sample_info(
+ self.lmdb_sets[lmdb_idx]['txn'], file_idx)
+ if sample_info is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+ img, label = sample_info
+ data = {'image': img, 'label': label}
+ outs = transform(data, self.ops[:-1])
+ if outs is not None:
+ outs = self.resize_norm_img(outs, ratio, padding=self.padding)
+ if outs is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+ outs = transform(outs, self.ops[-1:])
+ if outs is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+ return outs
+
+ def __len__(self):
+ return self.data_idx_order_list.shape[0]
diff --git a/tools/data/ratio_dataset_tvresize_test.py b/tools/data/ratio_dataset_tvresize_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..abf6ff566b2ad7c2e4a32be8949297a9e81d5318
--- /dev/null
+++ b/tools/data/ratio_dataset_tvresize_test.py
@@ -0,0 +1,276 @@
+import io
+import math
+import random
+import re
+import unicodedata
+
+import cv2
+import lmdb
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms as T
+from torchvision.transforms import functional as F
+
+from openrec.preprocess import create_operators, transform
+
+
+class CharsetAdapter:
+ """Transforms labels according to the target charset."""
+
+ def __init__(self, target_charset) -> None:
+ super().__init__()
+ self.lowercase_only = target_charset == target_charset.lower()
+ self.uppercase_only = target_charset == target_charset.upper()
+ self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')
+
+ def __call__(self, label):
+ if self.lowercase_only:
+ label = label.lower()
+ elif self.uppercase_only:
+ label = label.upper()
+ # Remove unsupported characters
+ label = self.unsupported.sub('', label)
+ return label
+
+
+class RatioDataSetTVResizeTest(Dataset):
+
+ def __init__(self, config, mode, logger, seed=None, epoch=1):
+ super(RatioDataSetTVResizeTest, self).__init__()
+ self.ds_width = config[mode]['dataset'].get('ds_width', True)
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+ max_ratio = loader_config.get('max_ratio', 10)
+ min_ratio = loader_config.get('min_ratio', 1)
+ data_dir_list = dataset_config['data_dir_list']
+ self.do_shuffle = loader_config['shuffle']
+ self.seed = epoch
+ self.max_text_length = global_config['max_text_length']
+ data_source_num = len(data_dir_list)
+ ratio_list = dataset_config.get('ratio_list', 1.0)
+ if isinstance(ratio_list, (float, int)):
+ ratio_list = [float(ratio_list)] * int(data_source_num)
+ assert len(
+ ratio_list
+ ) == data_source_num, 'The length of ratio_list should match that of data_dir_list.'
+ self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
+ data_dir_list, ratio_list)
+ for data_dir in data_dir_list:
+ logger.info('Initialize indices of datasets: %s' % data_dir)
+ self.logger = logger
+ data_idx_order_list = self.dataset_traversal()
+ character_dict_path = global_config.get('character_dict_path', None)
+ use_space_char = global_config.get('use_space_char', False)
+ if character_dict_path is None:
+ char_test = '0123456789abcdefghijklmnopqrstuvwxyz'
+ else:
+ char_test = ''
+ with open(character_dict_path, 'rb') as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip('\n').strip('\r\n')
+ char_test += line
+ if use_space_char:
+ char_test += ' '
+ wh_ratio, data_idx_order_list = self.get_wh_ratio(
+ data_idx_order_list, char_test)
+ self.data_idx_order_list = np.array(data_idx_order_list)
+ wh_ratio = np.around(np.array(wh_ratio))
+ self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
+ for i in range(max_ratio + 1):
+ logger.info((1 * (self.wh_ratio == i)).sum())
+ self.wh_ratio_sort = np.argsort(self.wh_ratio)
+ self.ops = create_operators(dataset_config['transforms'],
+ global_config)
+
+ self.need_reset = True in [x < 1 for x in ratio_list]
+ self.error = 0
+ self.base_shape = dataset_config.get(
+ 'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
+ self.base_h = dataset_config.get('base_h', 32)
+ self.interpolation = T.InterpolationMode.BICUBIC
+ transforms = []
+ transforms.extend([
+ T.ToTensor(),
+ T.Normalize(0.5, 0.5),
+ ])
+ self.transforms = T.Compose(transforms)
+
+ def get_wh_ratio(self, data_idx_order_list, char_test):
+ wh_ratio = []
+ wh_ratio_len = [[0 for _ in range(26)] for _ in range(11)]
+ data_idx_order_list_filter = []
+ charset_adapter = CharsetAdapter(char_test)
+
+ for idx in range(data_idx_order_list.shape[0]):
+ lmdb_idx, file_idx = data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ wh_key = 'wh-%09d'.encode() % file_idx
+ wh = self.lmdb_sets[lmdb_idx]['txn'].get(wh_key)
+ if wh is None:
+ img_key = f'image-{file_idx:09d}'.encode()
+ img = self.lmdb_sets[lmdb_idx]['txn'].get(img_key)
+ buf = io.BytesIO(img)
+ w, h = Image.open(buf).size
+ else:
+ wh = wh.decode('utf-8')
+ w, h = wh.split('_')
+
+ label_key = 'label-%09d'.encode() % file_idx
+ label = self.lmdb_sets[lmdb_idx]['txn'].get(label_key)
+ if label is not None:
+ label = label.decode('utf-8')
+ # Remove all whitespace
+ label = ''.join(label.split())
+ # Normalize unicode composites (if any) and convert to compatible ASCII characters
+ label = unicodedata.normalize('NFKD',
+ label).encode('ascii',
+ 'ignore').decode()
+ # Filter by length before removing unsupported characters. The original label might be too long.
+ if len(label) > self.max_text_length:
+ continue
+ label = charset_adapter(label)
+ if not label:
+ continue
+
+ wh_ratio.append(float(w) / float(h))
+ wh_ratio_len[int(float(w) /
+ float(h)) if int(float(w) /
+ float(h)) <= 10 else
+ 10][len(label) if len(label) <= 25 else 25] += 1
+ data_idx_order_list_filter.append([lmdb_idx, file_idx])
+ self.logger.info(wh_ratio_len)
+ return wh_ratio, data_idx_order_list_filter
+
+ def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
+ lmdb_sets = {}
+ dataset_idx = 0
+ for dirpath, ratio in zip(data_dir_list, ratio_list):
+ env = lmdb.open(dirpath,
+ max_readers=32,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False)
+ txn = env.begin(write=False)
+ num_samples = int(txn.get('num-samples'.encode()))
+ lmdb_sets[dataset_idx] = {
+ 'dirpath': dirpath,
+ 'env': env,
+ 'txn': txn,
+ 'num_samples': num_samples,
+ 'ratio_num_samples': int(ratio * num_samples),
+ }
+ dataset_idx += 1
+ return lmdb_sets
+
+ def dataset_traversal(self):
+ lmdb_num = len(self.lmdb_sets)
+ total_sample_num = 0
+ for lno in range(lmdb_num):
+ total_sample_num += self.lmdb_sets[lno]['ratio_num_samples']
+ data_idx_order_list = np.zeros((total_sample_num, 2))
+ beg_idx = 0
+ for lno in range(lmdb_num):
+ tmp_sample_num = self.lmdb_sets[lno]['ratio_num_samples']
+ end_idx = beg_idx + tmp_sample_num
+ data_idx_order_list[beg_idx:end_idx, 0] = lno
+ data_idx_order_list[beg_idx:end_idx, 1] = list(
+ random.sample(range(1, self.lmdb_sets[lno]['num_samples'] + 1),
+ self.lmdb_sets[lno]['ratio_num_samples']))
+ beg_idx = beg_idx + tmp_sample_num
+ return data_idx_order_list
+
+ def get_img_data(self, value):
+ """Decode raw image bytes from LMDB into a BGR numpy array."""
+ if not value:
+ return None
+ imgdata = np.frombuffer(value, dtype='uint8')
+ if imgdata is None:
+ return None
+ imgori = cv2.imdecode(imgdata, 1)
+ if imgori is None:
+ return None
+ return imgori
+
+ def resize_norm_img(self, data, gen_ratio, padding=True):
+ img = data['image']
+ w, h = img.size
+ imgW, imgH = self.base_shape[gen_ratio - 1] if gen_ratio <= 4 else [
+ self.base_h * gen_ratio, self.base_h
+ ]
+ use_ratio = imgW // imgH
+ if use_ratio >= (w // h) + 2:
+ self.error += 1
+ return None
+ if not padding:
+ resized_w = imgW
+ else:
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(
+ math.ceil(imgH * ratio * (random.random() + 0.5)))
+ resized_w = min(imgW, resized_w)
+ resized_image = F.resize(img, (imgH, resized_w),
+ interpolation=self.interpolation)
+ img = self.transforms(resized_image)
+ if resized_w < imgW and padding:
+ img = F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.)
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ data['image'] = img
+ data['valid_ratio'] = valid_ratio
+ data['gen_ratio'] = imgW // imgH
+ r = float(w) / float(h)
+ data['real_ratio'] = max(1, round(r))
+ return data
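+ # Worked example for the defaults above (base_shape/base_h): gen_ratio=3
+ # selects (imgW, imgH) = (112, 40), gen_ratio=4 selects (128, 32), and any
+ # gen_ratio > 4 falls back to (base_h * gen_ratio, base_h), e.g. (192, 32)
+ # for gen_ratio=6. Samples whose integer w // h is smaller than the target
+ # ratio by 2 or more are skipped and counted in self.error.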
+
+ def get_lmdb_sample_info(self, txn, index):
+ label_key = 'label-%09d'.encode() % index
+ label = txn.get(label_key)
+ if label is None:
+ return None
+ label = label.decode('utf-8')
+ img_key = 'image-%09d'.encode() % index
+ imgbuf = txn.get(img_key)
+ return imgbuf, label
+
+ def __getitem__(self, properties):
+ img_width = properties[0]
+ img_height = properties[1]
+ idx = properties[2]
+ ratio = properties[3]
+ lmdb_idx, file_idx = self.data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ sample_info = self.get_lmdb_sample_info(
+ self.lmdb_sets[lmdb_idx]['txn'], file_idx)
+ if sample_info is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+ img, label = sample_info
+ data = {'image': img, 'label': label}
+ outs = transform(data, self.ops[:-1])
+ if outs is not None:
+ outs = self.resize_norm_img(outs, ratio, padding=False)
+ if outs is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+
+ outs = transform(outs, self.ops[-1:])
+ if outs is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+ return outs
+
+ def __len__(self):
+ return self.data_idx_order_list.shape[0]
diff --git a/tools/data/ratio_sampler.py b/tools/data/ratio_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0c9d720b5835092a1ab4370db798a569b863ab6
--- /dev/null
+++ b/tools/data/ratio_sampler.py
@@ -0,0 +1,190 @@
+import math
+import os
+import random
+
+import numpy as np
+import torch
+from torch.utils.data import Sampler
+
+
+class RatioSampler(Sampler):
+
+ def __init__(self,
+ data_source,
+ scales,
+ first_bs=512,
+ fix_bs=True,
+ divided_factor=[8, 16],
+ is_training=True,
+ max_ratio=10,
+ max_bs=1024,
+ seed=None):
+ """
+ Multi-scale, aspect-ratio-aware batch sampler.
+ Args:
+ data_source(dataset): dataset exposing ds_width, wh_ratio and wh_ratio_sort
+ scales(list): several scales for image resolution
+ first_bs(int): batch size for the first scale in scales
+ fix_bs(boolean): if True, keep first_bs for every scale instead of rescaling the batch size
+ divided_factor(list[w, h]): backbone networks down-sample images by a fixed factor; width and height are rounded down to multiples of divided_factor
+ is_training(boolean): whether the sampler is used for training
+ max_ratio(int): upper bound for the width/height ratio buckets
+ max_bs(int): upper bound for the per-ratio batch size
+ """
+ # min. and max. spatial dimensions
+ self.data_source = data_source
+ # self.data_idx_order_list = np.array(data_source.data_idx_order_list)
+ self.ds_width = data_source.ds_width
+ self.seed = data_source.seed
+ if self.ds_width:
+ self.wh_ratio = data_source.wh_ratio
+ self.wh_ratio_sort = data_source.wh_ratio_sort
+ self.n_data_samples = len(self.data_source)
+ self.max_ratio = max_ratio
+ self.max_bs = max_bs
+
+ if isinstance(scales[0], list):
+ width_dims = [i[0] for i in scales]
+ height_dims = [i[1] for i in scales]
+ elif isinstance(scales[0], int):
+ width_dims = scales
+ height_dims = scales
+ base_im_w = width_dims[0]
+ base_im_h = height_dims[0]
+ base_batch_size = first_bs
+ base_elements = base_im_w * base_im_h * base_batch_size
+ self.base_elements = base_elements
+ self.base_batch_size = base_batch_size
+ self.base_im_h = base_im_h
+ self.base_im_w = base_im_w
+
+ # Get the GPU and node related information
+ num_replicas = torch.cuda.device_count()
+ # rank = dist.get_rank()
+ rank = (int(os.environ['LOCAL_RANK'])
+ if 'LOCAL_RANK' in os.environ else 0)
+ # self.rank = rank
+ # adjust the total samples to avoid batch dropping
+ num_samples_per_replica = int(
+ math.ceil(self.n_data_samples * 1.0 / num_replicas))
+
+ img_indices = [idx for idx in range(self.n_data_samples)]
+ self.shuffle = False
+ if is_training:
+ # compute the spatial dimensions and corresponding batch size
+ # ImageNet-style backbones down-sample images by a fixed factor.
+ # Round width and height down to multiples of divided_factor.
+ width_dims = [
+ int((w // divided_factor[0]) * divided_factor[0])
+ for w in width_dims
+ ]
+ height_dims = [
+ int((h // divided_factor[1]) * divided_factor[1])
+ for h in height_dims
+ ]
+
+ img_batch_pairs = list()
+ for (h, w) in zip(height_dims, width_dims):
+ if fix_bs:
+ batch_size = base_batch_size
+ else:
+ batch_size = int(max(1, (base_elements / (h * w))))
+ img_batch_pairs.append((w, h, batch_size))
+ self.img_batch_pairs = img_batch_pairs
+ self.shuffle = True
+ np.random.seed(seed)
+ random.seed(seed)
+ else:
+ self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]
+
+ self.img_indices = img_indices
+ self.n_samples_per_replica = num_samples_per_replica
+ self.epoch = 0
+ self.rank = rank
+ self.num_replicas = num_replicas
+
+ # self.batch_list = []
+ self.current = 0
+ self.is_training = is_training
+ if is_training:
+ indices_rank_i = self.img_indices[
+ self.rank:len(self.img_indices):self.num_replicas]
+ else:
+ indices_rank_i = self.img_indices
+ self.indices_rank_i_ori = np.array(self.wh_ratio_sort[indices_rank_i])
+ self.indices_rank_i_ratio = self.wh_ratio[self.indices_rank_i_ori]
+ indices_rank_i_ratio_unique = np.unique(self.indices_rank_i_ratio)
+ self.indices_rank_i_ratio_unique = indices_rank_i_ratio_unique.tolist()
+ self.batch_list = self.create_batch()
+ self.length = len(self.batch_list)
+ self.batchs_in_one_epoch_id = [i for i in range(self.length)]
+
+ def create_batch(self):
+ batch_list = []
+ for ratio in self.indices_rank_i_ratio_unique:
+ ratio_ids = np.where(self.indices_rank_i_ratio == ratio)[0]
+ ratio_ids = self.indices_rank_i_ori[ratio_ids]
+ if self.shuffle:
+ random.shuffle(ratio_ids)
+ num_ratio = ratio_ids.shape[0]
+ if ratio < 5:
+ batch_size_ratio = self.base_batch_size
+ else:
+ batch_size_ratio = min(
+ self.max_bs,
+ int(
+ max(1, (self.base_elements /
+ (self.base_im_h * ratio * self.base_im_h)))))
+ if num_ratio > batch_size_ratio:
+ batch_num_ratio = num_ratio // batch_size_ratio
+ print(self.rank, num_ratio, ratio * self.base_im_h,
+ batch_num_ratio, batch_size_ratio)
+ ratio_ids_full = ratio_ids[:batch_num_ratio *
+ batch_size_ratio].reshape(
+ batch_num_ratio,
+ batch_size_ratio, 1)
+ w = np.full_like(ratio_ids_full, ratio * self.base_im_h)
+ h = np.full_like(ratio_ids_full, self.base_im_h)
+ ra_wh = np.full_like(ratio_ids_full, ratio)
+ ratio_ids_full = np.concatenate([w, h, ratio_ids_full, ra_wh],
+ axis=-1)
+ batch_ratio = ratio_ids_full.tolist()
+
+ if batch_num_ratio * batch_size_ratio < num_ratio:
+ drop = ratio_ids[batch_num_ratio * batch_size_ratio:]
+ if self.is_training:
+ drop_full = ratio_ids[:batch_size_ratio - (
+ num_ratio - batch_num_ratio * batch_size_ratio)]
+ drop = np.append(drop_full, drop)
+ drop = drop.reshape(-1, 1)
+ w = np.full_like(drop, ratio * self.base_im_h)
+ h = np.full_like(drop, self.base_im_h)
+ ra_wh = np.full_like(drop, ratio)
+
+ drop = np.concatenate([w, h, drop, ra_wh], axis=-1)
+
+ batch_ratio.append(drop.tolist())
+ batch_list += batch_ratio
+ else:
+ print(self.rank, num_ratio, ratio * self.base_im_h,
+ batch_size_ratio)
+ ratio_ids = ratio_ids.reshape(-1, 1)
+ w = np.full_like(ratio_ids, ratio * self.base_im_h)
+ h = np.full_like(ratio_ids, self.base_im_h)
+ ra_wh = np.full_like(ratio_ids, ratio)
+
+ ratio_ids = np.concatenate([w, h, ratio_ids, ra_wh], axis=-1)
+ batch_list.append(ratio_ids.tolist())
+ return batch_list
+
+ def __iter__(self):
+ if self.shuffle or self.is_training:
+ random.seed(self.epoch)
+ self.epoch += 1
+ self.batch_list = self.create_batch()
+ random.shuffle(self.batchs_in_one_epoch_id)
+ for batch_tuple_id in self.batchs_in_one_epoch_id:
+ yield self.batch_list[batch_tuple_id]
+
+ def set_epoch(self, epoch: int):
+ self.epoch = epoch
+
+ def __len__(self):
+ return self.length
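+ # Minimal wiring sketch (illustrative only, argument values are made up):
+ # the sampler is meant to be passed as ``batch_sampler``, so each yielded
+ # batch is a list of [img_width, img_height, sample_idx, wh_ratio] entries
+ # that the ratio datasets' __getitem__ consumes one by one.
+ #
+ #   from torch.utils.data import DataLoader
+ #   sampler = RatioSampler(dataset, scales=[[128, 32]], first_bs=256,
+ #                          fix_bs=False, divided_factor=[8, 16],
+ #                          is_training=True, max_ratio=12, seed=0)
+ #   loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4)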
diff --git a/tools/data/simple_dataset.py b/tools/data/simple_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e793041642621b1d20c06bc773e4f89334a76b2
--- /dev/null
+++ b/tools/data/simple_dataset.py
@@ -0,0 +1,263 @@
+import json
+import math
+import os
+import random
+import traceback
+
+import cv2
+import numpy as np
+from torch.utils.data import Dataset
+
+from openrec.preprocess import create_operators, transform
+
+
+class SimpleDataSet(Dataset):
+
+ def __init__(self, config, mode, logger, seed=None, epoch=0):
+ super(SimpleDataSet, self).__init__()
+ self.logger = logger
+ self.mode = mode.lower()
+
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+
+ self.delimiter = dataset_config.get('delimiter', '\t')
+ label_file_list = dataset_config.pop('label_file_list')
+ data_source_num = len(label_file_list)
+ ratio_list = dataset_config.get('ratio_list', 1.0)
+ if isinstance(ratio_list, (float, int)):
+ ratio_list = [float(ratio_list)] * int(data_source_num)
+
+ assert len(
+ ratio_list
+ ) == data_source_num, 'The length of ratio_list should be the same as the file_list.'
+ self.data_dir = dataset_config['data_dir']
+ self.do_shuffle = loader_config['shuffle']
+ self.seed = seed
+ logger.info(f'Initialize indexes of datasets: {label_file_list}')
+ self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
+ self.data_idx_order_list = list(range(len(self.data_lines)))
+ if self.mode == 'train' and self.do_shuffle:
+ self.shuffle_data_random()
+
+ self.set_epoch_as_seed(self.seed, dataset_config)
+
+ self.ops = create_operators(dataset_config['transforms'],
+ global_config)
+ self.ext_op_transform_idx = dataset_config.get('ext_op_transform_idx',
+ 2)
+ self.need_reset = True in [x < 1 for x in ratio_list]
+
+ def set_epoch_as_seed(self, seed, dataset_config):
+ if self.mode == 'train':
+ try:
+ border_map_id = [
+ index for index, dictionary in enumerate(
+ dataset_config['transforms'])
+ if 'MakeBorderMap' in dictionary
+ ][0]
+ shrink_map_id = [
+ index for index, dictionary in enumerate(
+ dataset_config['transforms'])
+ if 'MakeShrinkMap' in dictionary
+ ][0]
+ dataset_config['transforms'][border_map_id]['MakeBorderMap'][
+ 'epoch'] = seed if seed is not None else 0
+ dataset_config['transforms'][shrink_map_id]['MakeShrinkMap'][
+ 'epoch'] = seed if seed is not None else 0
+ except Exception:
+ return
+
+ def get_image_info_list(self, file_list, ratio_list):
+ if isinstance(file_list, str):
+ file_list = [file_list]
+ data_lines = []
+ for idx, file in enumerate(file_list):
+ with open(file, 'rb') as f:
+ lines = f.readlines()
+ if self.mode == 'train' or ratio_list[idx] < 1.0:
+ random.seed(self.seed)
+ lines = random.sample(lines,
+ round(len(lines) * ratio_list[idx]))
+ data_lines.extend(lines)
+ return data_lines
+
+ def shuffle_data_random(self):
+ random.seed(self.seed)
+ random.shuffle(self.data_lines)
+ return
+
+ def _try_parse_filename_list(self, file_name):
+ # multiple images -> one gt label
+ if len(file_name) > 0 and file_name[0] == '[':
+ try:
+ info = json.loads(file_name)
+ file_name = random.choice(info)
+ except:
+ pass
+ return file_name
+
+ def get_ext_data(self):
+ ext_data_num = 0
+ for op in self.ops:
+ if hasattr(op, 'ext_data_num'):
+ ext_data_num = getattr(op, 'ext_data_num')
+ break
+ load_data_ops = self.ops[:self.ext_op_transform_idx]
+ ext_data = []
+
+ while len(ext_data) < ext_data_num:
+ file_idx = self.data_idx_order_list[np.random.randint(
+ self.__len__())]
+ data_line = self.data_lines[file_idx]
+ data_line = data_line.decode('utf-8')
+ substr = data_line.strip('\n').split(self.delimiter)
+ file_name = substr[0]
+ file_name = self._try_parse_filename_list(file_name)
+ label = substr[1]
+ img_path = os.path.join(self.data_dir, file_name)
+ data = {'img_path': img_path, 'label': label}
+ if not os.path.exists(img_path):
+ continue
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ data = transform(data, load_data_ops)
+
+ if data is None:
+ continue
+ if 'polys' in data.keys():
+ if data['polys'].shape[1] != 4:
+ continue
+ ext_data.append(data)
+ return ext_data
+
+ def __getitem__(self, idx):
+ file_idx = self.data_idx_order_list[idx]
+ data_line = self.data_lines[file_idx]
+ try:
+ data_line = data_line.decode('utf-8')
+ substr = data_line.strip('\n').split(self.delimiter)
+ file_name = substr[0]
+ file_name = self._try_parse_filename_list(file_name)
+ label = substr[1]
+ img_path = os.path.join(self.data_dir, file_name)
+ data = {'img_path': img_path, 'label': label}
+
+ if not os.path.exists(img_path):
+ raise Exception('{} does not exist!'.format(img_path))
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ data['ext_data'] = self.get_ext_data()
+ outs = transform(data, self.ops)
+ except:
+ self.logger.error(
+ 'When parsing line {}, error happened with msg: {}'.format(
+ data_line, traceback.format_exc()))
+ outs = None
+ if outs is None:
+ # during evaluation, we should fix the idx to get same results for many times of evaluation.
+ rnd_idx = np.random.randint(self.__len__(
+ )) if self.mode == 'train' else (idx + 1) % self.__len__()
+ return self.__getitem__(rnd_idx)
+ return outs
+
+ def __len__(self):
+ return len(self.data_idx_order_list)
+
+
+class MultiScaleDataSet(SimpleDataSet):
+
+ def __init__(self, config, mode, logger, seed=None):
+ super(MultiScaleDataSet, self).__init__(config, mode, logger, seed)
+ self.ds_width = config[mode]['dataset'].get('ds_width', False)
+ if self.ds_width:
+ self.wh_aware()
+
+ def wh_aware(self):
+ data_line_new = []
+ wh_ratio = []
+ for lins in self.data_lines:
+ data_line_new.append(lins)
+ lins = lins.decode('utf-8')
+ name, label, w, h = lins.strip('\n').split(self.delimiter)
+ wh_ratio.append(float(w) / float(h))
+
+ self.data_lines = data_line_new
+ self.wh_ratio = np.array(wh_ratio)
+ self.wh_ratio_sort = np.argsort(self.wh_ratio)
+ self.data_idx_order_list = list(range(len(self.data_lines)))
+
+ def resize_norm_img(self, data, imgW, imgH, padding=True):
+ img = data['image']
+ h = img.shape[0]
+ w = img.shape[1]
+ if not padding:
+ resized_image = cv2.resize(img, (imgW, imgH),
+ interpolation=cv2.INTER_LINEAR)
+ resized_w = imgW
+ else:
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ resized_image = resized_image.astype('float32')
+
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
+ padding_im[:, :, :resized_w] = resized_image
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ data['image'] = padding_im
+ data['valid_ratio'] = valid_ratio
+ return data
+
+ def __getitem__(self, properties):
+ # properties is a list containing (width, height, index, wh_ratio)
+ img_height = properties[1]
+ idx = properties[2]
+ if self.ds_width and properties[3] is not None:
+ wh_ratio = properties[3]
+ img_width = img_height * (1 if int(round(wh_ratio)) == 0 else int(
+ round(wh_ratio)))
+ file_idx = self.wh_ratio_sort[idx]
+ else:
+ file_idx = self.data_idx_order_list[idx]
+ img_width = properties[0]
+ wh_ratio = None
+
+ data_line = self.data_lines[file_idx]
+ try:
+ data_line = data_line.decode('utf-8')
+ substr = data_line.strip('\n').split(self.delimiter)
+ file_name = substr[0]
+ file_name = self._try_parse_filename_list(file_name)
+ label = substr[1]
+ img_path = os.path.join(self.data_dir, file_name)
+ data = {'img_path': img_path, 'label': label}
+ if not os.path.exists(img_path):
+ raise Exception('{} does not exist!'.format(img_path))
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ data['ext_data'] = self.get_ext_data()
+ outs = transform(data, self.ops[:-1])
+ if outs is not None:
+ outs = self.resize_norm_img(outs, img_width, img_height)
+ outs = transform(outs, self.ops[-1:])
+ except:
+ self.logger.error(
+ 'When parsing line {}, error happened with msg: {}'.format(
+ data_line, traceback.format_exc()))
+ outs = None
+ if outs is None:
+ # during evaluation, we should fix the idx to get same results for many times of evaluation.
+ rnd_idx = np.random.randint(self.__len__(
+ )) if self.mode == 'train' else (idx + 1) % self.__len__()
+ return self.__getitem__([img_width, img_height, rnd_idx, wh_ratio])
+ return outs
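+ # Label-file line formats implied by the parsing above (paths are
+ # illustrative):
+ #   SimpleDataSet:                imgs/word_001.jpg<TAB>hello
+ #   MultiScaleDataSet (ds_width): imgs/word_001.jpg<TAB>hello<TAB>120<TAB>32
+ # The trailing width/height columns are only required when ds_width is True,
+ # where they feed wh_aware() to build the aspect-ratio index.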
diff --git a/tools/data/strlmdb_dataset.py b/tools/data/strlmdb_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..76ce4398b0de71ac67201cea5759e1fec64831d5
--- /dev/null
+++ b/tools/data/strlmdb_dataset.py
@@ -0,0 +1,143 @@
+import os
+
+import cv2
+import lmdb
+import numpy as np
+from torch.utils.data import Dataset
+
+from openrec.preprocess import create_operators, transform
+
+
+class STRLMDBDataSet(Dataset):
+
+ def __init__(self, config, mode, logger, seed=None, epoch=1, gpu_i=0):
+ super(STRLMDBDataSet, self).__init__()
+
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+ loader_config['batch_size_per_card']
+ # data_dir = dataset_config['data_dir']
+ data_dir = '../training_aug_lmdb_noerror/ep' + str(epoch)
+ self.do_shuffle = loader_config['shuffle']
+
+ self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
+ logger.info('Initialize indexes of datasets: %s' % data_dir)
+ self.data_idx_order_list = self.dataset_traversal()
+ if self.do_shuffle:
+ np.random.shuffle(self.data_idx_order_list)
+ self.ops = create_operators(dataset_config['transforms'],
+ global_config)
+ self.ext_op_transform_idx = dataset_config.get('ext_op_transform_idx',
+ 1)
+
+ dataset_config.get('ratio_list', [1.0])
+ self.need_reset = True # in [x < 1 for x in ratio_list]
+
+ def load_hierarchical_lmdb_dataset(self, data_dir):
+ lmdb_sets = {}
+ dataset_idx = 0
+ for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
+ if not dirnames:
+ env = lmdb.open(
+ dirpath,
+ max_readers=32,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False,
+ )
+ txn = env.begin(write=False)
+ num_samples = int(txn.get('num-samples'.encode()))
+ lmdb_sets[dataset_idx] = {
+ 'dirpath': dirpath,
+ 'env': env,
+ 'txn': txn,
+ 'num_samples': num_samples,
+ }
+ dataset_idx += 1
+ return lmdb_sets
+
+ def dataset_traversal(self):
+ lmdb_num = len(self.lmdb_sets)
+ total_sample_num = 0
+ for lno in range(lmdb_num):
+ total_sample_num += self.lmdb_sets[lno]['num_samples']
+ data_idx_order_list = np.zeros((total_sample_num, 2))
+ beg_idx = 0
+ for lno in range(lmdb_num):
+ tmp_sample_num = self.lmdb_sets[lno]['num_samples']
+ end_idx = beg_idx + tmp_sample_num
+ data_idx_order_list[beg_idx:end_idx, 0] = lno
+ data_idx_order_list[beg_idx:end_idx,
+ 1] = list(range(tmp_sample_num))
+ data_idx_order_list[beg_idx:end_idx, 1] += 1
+ beg_idx = beg_idx + tmp_sample_num
+ return data_idx_order_list
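+ # Resulting layout: each row of data_idx_order_list is
+ # [lmdb_set_index, sample_index], with sample_index starting at 1 to match
+ # the 1-based 'label-%09d' / 'image-%09d' keys stored in the LMDB files.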
+
+ def get_img_data(self, value):
+ """Decode raw image bytes from LMDB into a BGR numpy array."""
+ if not value:
+ return None
+ imgdata = np.frombuffer(value, dtype='uint8')
+ if imgdata is None:
+ return None
+ imgori = cv2.imdecode(imgdata, 1)
+ if imgori is None:
+ return None
+ return imgori
+
+ def get_ext_data(self):
+ ext_data_num = 0
+ for op in self.ops:
+ if hasattr(op, 'ext_data_num'):
+ ext_data_num = getattr(op, 'ext_data_num')
+ break
+ load_data_ops = self.ops[:self.ext_op_transform_idx]
+ ext_data = []
+
+ while len(ext_data) < ext_data_num:
+ lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(
+ len(self))]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ sample_info = self.get_lmdb_sample_info(
+ self.lmdb_sets[lmdb_idx]['txn'], file_idx)
+ if sample_info is None:
+ continue
+ img, label = sample_info
+ data = {'image': img, 'label': label}
+ data = transform(data, load_data_ops)
+ if data is None:
+ continue
+ ext_data.append(data)
+ return ext_data
+
+ def get_lmdb_sample_info(self, txn, index):
+ label_key = 'label-%09d'.encode() % index
+ label = txn.get(label_key)
+ if label is None:
+ return None
+ label = label.decode('utf-8')
+ img_key = 'image-%09d'.encode() % index
+ imgbuf = txn.get(img_key)
+ return imgbuf, label
+
+ def __getitem__(self, idx):
+ lmdb_idx, file_idx = self.data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ sample_info = self.get_lmdb_sample_info(
+ self.lmdb_sets[lmdb_idx]['txn'], file_idx)
+ if sample_info is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ img, label = sample_info
+ data = {'image': img, 'label': label}
+ data['ext_data'] = self.get_ext_data()
+ outs = transform(data, self.ops)
+ if outs is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ return outs
+
+ def __len__(self):
+ return self.data_idx_order_list.shape[0]
diff --git a/tools/engine/__init__.py b/tools/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a54e7502290649fd9ebd57e4dc9c4a9e55dc7ef5
--- /dev/null
+++ b/tools/engine/__init__.py
@@ -0,0 +1,5 @@
+from . import config, trainer
+from .config import *
+from .trainer import *
+
+__all__ = config.__all__ + trainer.__all__
diff --git a/tools/engine/__pycache__/__init__.cpython-38.pyc b/tools/engine/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f677a26af327738b636c2d87626b7608a3d34307
Binary files /dev/null and b/tools/engine/__pycache__/__init__.cpython-38.pyc differ
diff --git a/tools/engine/__pycache__/config.cpython-38.pyc b/tools/engine/__pycache__/config.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cad91bed268d7f028f0163cc8d624de4258784d8
Binary files /dev/null and b/tools/engine/__pycache__/config.cpython-38.pyc differ
diff --git a/tools/engine/__pycache__/trainer.cpython-38.pyc b/tools/engine/__pycache__/trainer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0f215df15c12f00351fe0616cea2ea4866be809
Binary files /dev/null and b/tools/engine/__pycache__/trainer.cpython-38.pyc differ
diff --git a/tools/engine/config.py b/tools/engine/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d16dd21c46b11d76961fb689f0ac2b2c449e77c7
--- /dev/null
+++ b/tools/engine/config.py
@@ -0,0 +1,158 @@
+import os
+from argparse import ArgumentParser, RawDescriptionHelpFormatter
+from collections.abc import Mapping
+
+import yaml
+
+__all__ = ['Config']
+
+
+class ArgsParser(ArgumentParser):
+
+ def __init__(self):
+ super(ArgsParser,
+ self).__init__(formatter_class=RawDescriptionHelpFormatter)
+ self.add_argument('-c',
+ '--config',
+ help='configuration file to use')
+ self.add_argument('-o',
+ '--opt',
+ nargs='*',
+ help='set configuration options')
+ self.add_argument('--local_rank')
+
+ def parse_args(self, argv=None):
+ args = super(ArgsParser, self).parse_args(argv)
+ assert args.config is not None, 'Please specify --config=configure_file_path.'
+ args.opt = self._parse_opt(args.opt)
+ return args
+
+ def _parse_opt(self, opts):
+ config = {}
+ if not opts:
+ return config
+ for s in opts:
+ s = s.strip()
+ k, v = s.split('=', 1)
+ if '.' not in k:
+ config[k] = yaml.load(v, Loader=yaml.Loader)
+ else:
+ keys = k.split('.')
+ if keys[0] not in config:
+ config[keys[0]] = {}
+ cur = config[keys[0]]
+ for idx, key in enumerate(keys[1:]):
+ if idx == len(keys) - 2:
+ cur[key] = yaml.load(v, Loader=yaml.Loader)
+ else:
+ cur[key] = {}
+ cur = cur[key]
+ return config
+
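+ # Example of the override syntax handled above (values are illustrative):
+ #   -o Global.epoch_num=20 Train.loader.shuffle=False
+ # is parsed into
+ #   {'Global': {'epoch_num': 20}, 'Train': {'loader': {'shuffle': False}}}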
+
+class AttrDict(dict):
+ """Single level attribute dict, NOT recursive."""
+
+ def __init__(self, **kwargs):
+ super(AttrDict, self).__init__()
+ super(AttrDict, self).update(kwargs)
+
+ def __getattr__(self, key):
+ if key in self:
+ return self[key]
+ raise AttributeError("object has no attribute '{}'".format(key))
+
+
+def _merge_dict(config, merge_dct):
+ """Recursive dict merge. Inspired by :meth:`dict.update()`: instead of
+ updating only top-level keys, it recurses into nested dicts to an
+ arbitrary depth. ``merge_dct`` is merged into ``config``; dotted keys
+ such as 'Global.epoch_num' are expanded into nested keys.
+
+ Args:
+ config: dict onto which the merge is executed
+ merge_dct: dict merged into config
+
+ Returns: the merged config
+ """
+ for key, value in merge_dct.items():
+ sub_keys = key.split('.')
+ key = sub_keys[0]
+ if key in config and len(sub_keys) > 1:
+ _merge_dict(config[key], {'.'.join(sub_keys[1:]): value})
+ elif key in config and isinstance(config[key], dict) and isinstance(
+ value, Mapping):
+ _merge_dict(config[key], value)
+ else:
+ config[key] = value
+ return config
+
+
+def print_dict(cfg, print_func=print, delimiter=0):
+ """Recursively print a dict, indenting according to the nesting depth of keys."""
+ for k, v in sorted(cfg.items()):
+ if isinstance(v, dict):
+ print_func('{}{} : '.format(delimiter * ' ', str(k)))
+ print_dict(v, print_func, delimiter + 4)
+ elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
+ print_func('{}{} : '.format(delimiter * ' ', str(k)))
+ for value in v:
+ print_dict(value, print_func, delimiter + 4)
+ else:
+ print_func('{}{} : {}'.format(delimiter * ' ', k, v))
+
+
+class Config(object):
+
+ def __init__(self, config_path, BASE_KEY='_BASE_'):
+ self.BASE_KEY = BASE_KEY
+ self.cfg = self._load_config_with_base(config_path)
+
+ def _load_config_with_base(self, file_path):
+ """Load config from file.
+
+ Args:
+ file_path (str): Path of the config file to be loaded.
+
+ Returns: global config
+ """
+ _, ext = os.path.splitext(file_path)
+ assert ext in ['.yml', '.yaml'], 'only support yaml files for now'
+
+ with open(file_path) as f:
+ file_cfg = yaml.load(f, Loader=yaml.Loader)
+
+ # NOTE: cfgs outside have higher priority than cfgs in _BASE_
+ if self.BASE_KEY in file_cfg:
+ all_base_cfg = AttrDict()
+ base_ymls = list(file_cfg[self.BASE_KEY])
+ for base_yml in base_ymls:
+ if base_yml.startswith('~'):
+ base_yml = os.path.expanduser(base_yml)
+ if not base_yml.startswith('/'):
+ base_yml = os.path.join(os.path.dirname(file_path),
+ base_yml)
+
+ base_cfg = self._load_config_with_base(base_yml)
+ all_base_cfg = _merge_dict(all_base_cfg, base_cfg)
+
+ del file_cfg[self.BASE_KEY]
+ file_cfg = _merge_dict(all_base_cfg, file_cfg)
+ file_cfg['filename'] = os.path.splitext(
+ os.path.split(file_path)[-1])[0]
+ return file_cfg
+
+ def merge_dict(self, args):
+ self.cfg = _merge_dict(self.cfg, args)
+
+ def print_cfg(self, print_func=print):
+ """Recursively print the config, indenting according to the nesting depth of keys."""
+ print_func('----------- Config -----------')
+ print_dict(self.cfg, print_func)
+ print_func('---------------------------------------------')
+
+ def save(self, p, cfg=None):
+ if cfg is None:
+ cfg = self.cfg
+ with open(p, 'w') as f:
+ yaml.dump(dict(cfg), f, default_flow_style=False, sort_keys=False)
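+ # Typical usage sketch (file paths and override keys are illustrative):
+ #
+ #   cfg = Config('configs/some_model.yml')       # resolves any _BASE_ files
+ #   cfg.merge_dict({'Global.epoch_num': 20})     # dotted keys are supported
+ #   cfg.print_cfg()
+ #   cfg.save('output/config.yml')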
diff --git a/tools/engine/trainer.py b/tools/engine/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2382b4552b97fa85115a56ae8a5d7d1e5539178
--- /dev/null
+++ b/tools/engine/trainer.py
@@ -0,0 +1,621 @@
+import copy
+import datetime
+import os
+import random
+import time
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from openrec.losses import build_loss
+from openrec.metrics import build_metric
+from openrec.modeling import build_model
+from openrec.optimizer import build_optimizer
+from openrec.postprocess import build_post_process
+from tools.data import build_dataloader
+from tools.utils.ckpt import load_ckpt, save_ckpt
+from tools.utils.logging import get_logger
+from tools.utils.stats import TrainingStats
+from tools.utils.utility import AverageMeter
+
+__all__ = ['Trainer']
+
+
+def get_parameter_number(model):
+ total_num = sum(p.numel() for p in model.parameters())
+ trainable_num = sum(p.numel() for p in model.parameters()
+ if p.requires_grad)
+ return {'Total': total_num, 'Trainable': trainable_num}
+
+
+class Trainer(object):
+
+ def __init__(self, cfg, mode='train'):
+ self.cfg = cfg.cfg
+
+ self.local_rank = (int(os.environ['LOCAL_RANK'])
+ if 'LOCAL_RANK' in os.environ else 0)
+ self.set_device(self.cfg['Global']['device'])
+ mode = mode.lower()
+ assert mode in [
+ 'train_eval',
+ 'train',
+ 'eval',
+ 'test',
+ ], 'mode should be train_eval, train, eval or test'
+ if torch.cuda.device_count() > 1 and 'train' in mode:
+ torch.distributed.init_process_group(backend='nccl')
+ torch.cuda.set_device(self.device)
+ self.cfg['Global']['distributed'] = True
+ else:
+ self.cfg['Global']['distributed'] = False
+ self.local_rank = 0
+
+ self.cfg['Global']['output_dir'] = self.cfg['Global'].get(
+ 'output_dir', 'output')
+ os.makedirs(self.cfg['Global']['output_dir'], exist_ok=True)
+
+ self.writer = None
+ if self.local_rank == 0 and self.cfg['Global'][
+ 'use_tensorboard'] and 'train' in mode:
+ from torch.utils.tensorboard import SummaryWriter
+
+ self.writer = SummaryWriter(self.cfg['Global']['output_dir'])
+
+ self.logger = get_logger(
+ 'openrec',
+ os.path.join(self.cfg['Global']['output_dir'], 'train.log')
+ if 'train' in mode else None,
+ )
+
+ cfg.print_cfg(self.logger.info)
+
+ if self.cfg['Global']['device'] == 'gpu' and self.device.type == 'cpu':
+ self.logger.info('CUDA is not available, automatically switching to CPU')
+
+ self.grad_clip_val = self.cfg['Global'].get('grad_clip_val', 0)
+ self.all_ema = self.cfg['Global'].get('all_ema', True)
+ self.use_ema = self.cfg['Global'].get('use_ema', True)
+
+ self.set_random_seed(self.cfg['Global'].get('seed', 48))
+
+ # build data loader
+ self.train_dataloader = None
+ if 'train' in mode:
+ cfg.save(
+ os.path.join(self.cfg['Global']['output_dir'], 'config.yml'),
+ self.cfg)
+ self.train_dataloader = build_dataloader(self.cfg, 'Train',
+ self.logger)
+ self.logger.info(
+ f'train dataloader has {len(self.train_dataloader)} iters')
+ self.valid_dataloader = None
+ if 'eval' in mode and self.cfg['Eval']:
+ self.valid_dataloader = build_dataloader(self.cfg, 'Eval',
+ self.logger)
+ self.logger.info(
+ f'valid dataloader has {len(self.valid_dataloader)} iters')
+
+ # build post process
+ self.post_process_class = build_post_process(self.cfg['PostProcess'],
+ self.cfg['Global'])
+ # build model
+ # for rec algorithm
+ char_num = self.post_process_class.get_character_num()
+ self.cfg['Architecture']['Decoder']['out_channels'] = char_num
+
+ self.model = build_model(self.cfg['Architecture'])
+ self.logger.info(get_parameter_number(model=self.model))
+ self.model = self.model.to(self.device)
+
+ if self.local_rank == 0:
+ ema_model = build_model(self.cfg['Architecture'])
+ self.ema_model = ema_model.to(self.device)
+ self.ema_model.eval()
+
+ use_sync_bn = self.cfg['Global'].get('use_sync_bn', False)
+ if use_sync_bn:
+ self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
+ self.model)
+ self.logger.info('convert_sync_batchnorm')
+
+ # build loss
+ self.loss_class = build_loss(self.cfg['Loss'])
+
+ self.optimizer, self.lr_scheduler = None, None
+ if self.train_dataloader is not None:
+ # build optim
+ self.optimizer, self.lr_scheduler = build_optimizer(
+ self.cfg['Optimizer'],
+ self.cfg['LRScheduler'],
+ epochs=self.cfg['Global']['epoch_num'],
+ step_each_epoch=len(self.train_dataloader),
+ model=self.model,
+ )
+
+ self.eval_class = build_metric(self.cfg['Metric'])
+
+ self.status = load_ckpt(self.model, self.cfg, self.optimizer,
+ self.lr_scheduler)
+
+ if self.cfg['Global']['distributed']:
+ self.model = torch.nn.parallel.DistributedDataParallel(
+ self.model, [self.local_rank], find_unused_parameters=False)
+
+ # amp
+ self.scaler = (torch.cuda.amp.GradScaler() if self.cfg['Global'].get(
+ 'use_amp', False) else None)
+
+ self.logger.info(
+ f'run with torch {torch.__version__} and device {self.device}')
+
+ def load_params(self, params):
+ self.model.load_state_dict(params)
+
+ def set_random_seed(self, seed):
+ torch.manual_seed(seed) # set the random seed for the CPU
+ if self.device.type == 'cuda':
+ torch.backends.cudnn.benchmark = True
+ torch.cuda.manual_seed(seed) # set the random seed for the current GPU
+ torch.cuda.manual_seed_all(seed) # set the random seed for all GPUs
+ random.seed(seed)
+ np.random.seed(seed)
+
+ def set_device(self, device):
+ if device == 'gpu' and torch.cuda.is_available():
+ device = torch.device(f'cuda:{self.local_rank}')
+ else:
+ device = torch.device('cpu')
+ self.device = device
+
+ def train(self):
+ cal_metric_during_train = self.cfg['Global'].get(
+ 'cal_metric_during_train', False)
+ log_smooth_window = self.cfg['Global']['log_smooth_window']
+ epoch_num = self.cfg['Global']['epoch_num']
+ print_batch_step = self.cfg['Global']['print_batch_step']
+ eval_epoch_step = self.cfg['Global'].get('eval_epoch_step', 1)
+
+ start_eval_epoch = 0
+ if self.valid_dataloader is not None:
+ if type(eval_epoch_step) == list and len(eval_epoch_step) >= 2:
+ start_eval_epoch = eval_epoch_step[0]
+ eval_epoch_step = eval_epoch_step[1]
+ if len(self.valid_dataloader) == 0:
+ start_eval_epoch = 1e111
+ self.logger.info(
+ 'No Images in eval dataset, evaluation during training will be disabled'
+ )
+ self.logger.info(
+ f'During the training process, after the {start_eval_epoch}th epoch, '
+ f'an evaluation is run every {eval_epoch_step} epoch')
+ else:
+ start_eval_epoch = 1e111
+
+ eval_batch_step = self.cfg['Global']['eval_batch_step']
+
+ global_step = self.status.get('global_step', 0)
+
+ start_eval_step = 0
+ if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
+ start_eval_step = eval_batch_step[0]
+ eval_batch_step = eval_batch_step[1]
+ if len(self.valid_dataloader) == 0:
+ self.logger.info(
+ 'No Images in eval dataset, evaluation during training '
+ 'will be disabled')
+ start_eval_step = 1e111
+ self.logger.info(
+ 'During the training process, after the {}th iteration, '
+ 'an evaluation is run every {} iterations'.format(
+ start_eval_step, eval_batch_step))
+
+ start_epoch = self.status.get('epoch', 1)
+ best_metric = self.status.get('metrics', {})
+ if self.eval_class.main_indicator not in best_metric:
+ best_metric[self.eval_class.main_indicator] = 0
+ ema_best_metric = self.status.get('metrics', {})
+ ema_best_metric[self.eval_class.main_indicator] = 0
+ train_stats = TrainingStats(log_smooth_window, ['lr'])
+ self.model.train()
+
+ total_samples = 0
+ train_reader_cost = 0.0
+ train_batch_cost = 0.0
+ best_iter = 0
+ ema_stpe = 1
+ ema_eval_iter = 0
+ loss_avg = 0.
+ reader_start = time.time()
+ eta_meter = AverageMeter()
+
+ for epoch in range(start_epoch, epoch_num + 1):
+ if self.train_dataloader.dataset.need_reset:
+ self.train_dataloader = build_dataloader(
+ self.cfg,
+ 'Train',
+ self.logger,
+ epoch=epoch % 20 if epoch % 20 != 0 else 20,
+ )
+
+ for idx, batch in enumerate(self.train_dataloader):
+ batch = [t.to(self.device) for t in batch]
+ self.optimizer.zero_grad()
+ train_reader_cost += time.time() - reader_start
+ # use amp
+ if self.scaler:
+ with torch.cuda.amp.autocast():
+ preds = self.model(batch[0], data=batch[1:])
+ loss = self.loss_class(preds, batch)
+ self.scaler.scale(loss['loss']).backward()
+ if self.grad_clip_val > 0:
+ torch.nn.utils.clip_grad_norm_(
+ self.model.parameters(),
+ max_norm=self.grad_clip_val)
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ else:
+ preds = self.model(batch[0], data=batch[1:])
+ loss = self.loss_class(preds, batch)
+ avg_loss = loss['loss']
+ avg_loss.backward()
+ if self.grad_clip_val > 0:
+ torch.nn.utils.clip_grad_norm_(
+ self.model.parameters(),
+ max_norm=self.grad_clip_val)
+ self.optimizer.step()
+
+ if cal_metric_during_train: # only rec and cls need
+ post_result = self.post_process_class(preds,
+ batch,
+ training=True)
+ self.eval_class(post_result, batch, training=True)
+ metric = self.eval_class.get_metric()
+ train_stats.update(metric)
+
+ train_batch_time = time.time() - reader_start
+ train_batch_cost += train_batch_time
+ eta_meter.update(train_batch_time)
+ global_step += 1
+ total_samples += len(batch[0])
+
+ self.lr_scheduler.step()
+
+ if self.local_rank == 0 and self.use_ema and epoch > (
+ epoch_num - epoch_num // 10):
+ with torch.no_grad():
+ loss_currn = loss['loss'].detach().cpu().numpy().mean()
+ loss_avg = ((loss_avg *
+ (ema_stpe - 1)) + loss_currn) / (ema_stpe)
+ if ema_stpe == 1:
+
+ # current_weight = copy.deepcopy(self.model.module.state_dict())
+ ema_state_dict = copy.deepcopy(
+ self.model.module.state_dict() if self.
+ cfg['Global']['distributed'] else self.model.
+ state_dict())
+ self.ema_model.load_state_dict(ema_state_dict)
+ # if global_step > (epoch_num - epoch_num//10)*max_iter:
+ elif loss_currn <= loss_avg or self.all_ema:
+ # eval_batch_step = 500
+ current_weight = copy.deepcopy(
+ self.model.module.state_dict() if self.
+ cfg['Global']['distributed'] else self.model.
+ state_dict())
+ k1 = 1 / (ema_stpe + 1)
+ k2 = 1 - k1
+ for k, v in ema_state_dict.items():
+ # v = (v * (ema_stpe - 1) + current_weight[k])/ema_stpe
+ v = v * k2 + current_weight[k] * k1
+ # v.req = True
+ ema_state_dict[k] = v
+ # ema_stpe += 1
+ self.ema_model.load_state_dict(ema_state_dict)
+ ema_stpe += 1
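+ # With k1 = 1 / (ema_stpe + 1) and k2 = 1 - k1, each accepted snapshot is
+ # blended in with a shrinking weight, so ema_state_dict behaves like a
+ # cumulative average of the collected weights rather than an exponential
+ # moving average with a fixed decay rate.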
+ if global_step > start_eval_step and (
+ global_step -
+ start_eval_step) % eval_batch_step == 0:
+ ema_cur_metric = self.eval_ema()
+ ema_cur_metric_str = f"cur ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_cur_metric.items()])}"
+ self.logger.info(ema_cur_metric_str)
+ state = {
+ 'epoch': epoch,
+ 'global_step': global_step,
+ 'state_dict': self.ema_model.state_dict(),
+ 'optimizer': None,
+ 'scheduler': None,
+ 'config': self.cfg,
+ 'metrics': ema_cur_metric,
+ }
+ save_path = os.path.join(
+ self.cfg['Global']['output_dir'],
+ 'ema_' + str(ema_eval_iter) + '.pth')
+ torch.save(state, save_path)
+ self.logger.info(f'save ema ckpt to {save_path}')
+ ema_eval_iter += 1
+ if ema_cur_metric[self.eval_class.
+ main_indicator] >= ema_best_metric[
+ self.eval_class.main_indicator]:
+ ema_best_metric.update(ema_cur_metric)
+ ema_best_metric['best_epoch'] = epoch
+ best_ema_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}"
+ self.logger.info(best_ema_str)
+
+ # logger
+ stats = {
+ k: float(v)
+ if v.shape == [] else v.detach().cpu().numpy().mean()
+ for k, v in loss.items()
+ }
+ stats['lr'] = self.lr_scheduler.get_last_lr()[0]
+ train_stats.update(stats)
+
+ if self.writer is not None:
+ for k, v in train_stats.get().items():
+ self.writer.add_scalar(f'TRAIN/{k}', v, global_step)
+
+ if self.local_rank == 0 and (
+ (global_step > 0 and global_step % print_batch_step == 0)
+ or (idx >= len(self.train_dataloader) - 1)):
+ logs = train_stats.log()
+
+ eta_sec = (
+ (epoch_num + 1 - epoch) * len(self.train_dataloader) -
+ idx - 1) * eta_meter.avg
+ eta_sec_format = str(
+ datetime.timedelta(seconds=int(eta_sec)))
+ strs = (
+ f'epoch: [{epoch}/{epoch_num}], global_step: {global_step}, {logs}, '
+ f'avg_reader_cost: {train_reader_cost / print_batch_step:.5f} s, '
+ f'avg_batch_cost: {train_batch_cost / print_batch_step:.5f} s, '
+ f'avg_samples: {total_samples / print_batch_step}, '
+ f'ips: {total_samples / train_batch_cost:.5f} samples/s, '
+ f'eta: {eta_sec_format}')
+ self.logger.info(strs)
+ total_samples = 0
+ train_reader_cost = 0.0
+ train_batch_cost = 0.0
+ reader_start = time.time()
+ # eval
+ if (global_step > start_eval_step and
+ (global_step - start_eval_step) % eval_batch_step
+ == 0) and self.local_rank == 0:
+ cur_metric = self.eval()
+ cur_metric_str = f"cur metric, {', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])}"
+ self.logger.info(cur_metric_str)
+
+ # logger metric
+ if self.writer is not None:
+ for k, v in cur_metric.items():
+ if isinstance(v, (float, int)):
+ self.writer.add_scalar(f'EVAL/{k}',
+ cur_metric[k],
+ global_step)
+
+ if (cur_metric[self.eval_class.main_indicator] >=
+ best_metric[self.eval_class.main_indicator]):
+ best_metric.update(cur_metric)
+ best_metric['best_epoch'] = epoch
+ if self.writer is not None:
+ self.writer.add_scalar(
+ f'EVAL/best_{self.eval_class.main_indicator}',
+ best_metric[self.eval_class.main_indicator],
+ global_step,
+ )
+ if epoch > (epoch_num - epoch_num // 10 - 2):
+ save_ckpt(self.model,
+ self.cfg,
+ self.optimizer,
+ self.lr_scheduler,
+ epoch,
+ global_step,
+ best_metric,
+ is_best=True,
+ prefix='best_' + str(best_iter))
+ best_iter += 1
+ # else:
+ save_ckpt(self.model,
+ self.cfg,
+ self.optimizer,
+ self.lr_scheduler,
+ epoch,
+ global_step,
+ best_metric,
+ is_best=True,
+ prefix=None)
+ best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
+ self.logger.info(best_str)
+ if self.local_rank == 0 and epoch > start_eval_epoch and (
+ epoch - start_eval_epoch) % eval_epoch_step == 0:
+ cur_metric = self.eval()
+ cur_metric_str = f"cur metric, {', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])}"
+ self.logger.info(cur_metric_str)
+
+ # logger metric
+ if self.writer is not None:
+ for k, v in cur_metric.items():
+ if isinstance(v, (float, int)):
+ self.writer.add_scalar(f'EVAL/{k}', cur_metric[k],
+ global_step)
+
+ if (cur_metric[self.eval_class.main_indicator] >=
+ best_metric[self.eval_class.main_indicator]):
+ best_metric.update(cur_metric)
+ best_metric['best_epoch'] = epoch
+ if self.writer is not None:
+ self.writer.add_scalar(
+ f'EVAL/best_{self.eval_class.main_indicator}',
+ best_metric[self.eval_class.main_indicator],
+ global_step,
+ )
+ if epoch > (epoch_num - epoch_num // 10 - 2):
+ save_ckpt(self.model,
+ self.cfg,
+ self.optimizer,
+ self.lr_scheduler,
+ epoch,
+ global_step,
+ best_metric,
+ is_best=True,
+ prefix='best_' + str(best_iter))
+ best_iter += 1
+ # else:
+ save_ckpt(self.model,
+ self.cfg,
+ self.optimizer,
+ self.lr_scheduler,
+ epoch,
+ global_step,
+ best_metric,
+ is_best=True,
+ prefix=None)
+ best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
+ self.logger.info(best_str)
+
+ if self.local_rank == 0:
+ save_ckpt(self.model,
+ self.cfg,
+ self.optimizer,
+ self.lr_scheduler,
+ epoch,
+ global_step,
+ best_metric,
+ is_best=False,
+ prefix=None)
+ if epoch > (epoch_num - epoch_num // 10 - 2):
+ save_ckpt(self.model,
+ self.cfg,
+ self.optimizer,
+ self.lr_scheduler,
+ epoch,
+ global_step,
+ best_metric,
+ is_best=False,
+ prefix='epoch_' + str(epoch))
+ if self.use_ema and epoch > (epoch_num - epoch_num // 10):
+ # if global_step > start_eval_step and (global_step - start_eval_step) % eval_batch_step == 0:
+ ema_cur_metric = self.eval_ema()
+ ema_cur_metric_str = f"cur ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_cur_metric.items()])}"
+ self.logger.info(ema_cur_metric_str)
+ state = {
+ 'epoch': epoch,
+ 'global_step': global_step,
+ 'state_dict': self.ema_model.state_dict(),
+ 'optimizer': None,
+ 'scheduler': None,
+ 'config': self.cfg,
+ 'metrics': ema_cur_metric,
+ }
+ save_path = os.path.join(
+ self.cfg['Global']['output_dir'],
+ 'ema_' + str(ema_eval_iter) + '.pth')
+ torch.save(state, save_path)
+ self.logger.info(f'save ema ckpt to {save_path}')
+ ema_eval_iter += 1
+ if (ema_cur_metric[self.eval_class.main_indicator] >=
+ ema_best_metric[self.eval_class.main_indicator]):
+ ema_best_metric.update(ema_cur_metric)
+ ema_best_metric['best_epoch'] = epoch
+ # ema_cur_metric_str = f"best ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}"
+ best_ema_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}"
+ self.logger.info(best_ema_str)
+ best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
+ self.logger.info(best_str)
+ if self.writer is not None:
+ self.writer.close()
+ if torch.cuda.device_count() > 1:
+ torch.distributed.destroy_process_group()
+
+ def eval(self):
+ self.model.eval()
+ with torch.no_grad():
+ total_frame = 0.0
+ total_time = 0.0
+ pbar = tqdm(
+ total=len(self.valid_dataloader),
+ desc='eval model:',
+ position=0,
+ leave=True,
+ )
+ sum_images = 0
+ for idx, batch in enumerate(self.valid_dataloader):
+ batch = [t.to(self.device) for t in batch]
+ start = time.time()
+ if self.scaler:
+ with torch.cuda.amp.autocast():
+ preds = self.model(batch[0], data=batch[1:])
+ else:
+ preds = self.model(batch[0], data=batch[1:])
+
+ total_time += time.time() - start
+ # Obtain usable results from post-processing methods
+ # Evaluate the results of the current batch
+ post_result = self.post_process_class(preds, batch)
+ self.eval_class(post_result, batch)
+
+ pbar.update(1)
+ total_frame += len(batch[0])
+ sum_images += 1
+ # Get the final metric, e.g. acc or hmean
+ metric = self.eval_class.get_metric()
+
+ pbar.close()
+ self.model.train()
+ metric['fps'] = total_frame / total_time
+ return metric
+
+ def eval_ema(self):
+ # self.model.eval()
+ with torch.no_grad():
+ total_frame = 0.0
+ total_time = 0.0
+ pbar = tqdm(
+ total=len(self.valid_dataloader),
+ desc='eval ema_model:',
+ position=0,
+ leave=True,
+ )
+ sum_images = 0
+ for idx, batch in enumerate(self.valid_dataloader):
+ batch = [t.to(self.device) for t in batch]
+ start = time.time()
+ if self.scaler:
+ with torch.cuda.amp.autocast():
+ preds = self.ema_model(batch[0], data=batch[1:])
+ else:
+ preds = self.ema_model(batch[0], data=batch[1:])
+
+ total_time += time.time() - start
+ # Obtain usable results from post-processing methods
+ # Evaluate the results of the current batch
+ post_result = self.post_process_class(preds, batch)
+ self.eval_class(post_result, batch)
+
+ pbar.update(1)
+ total_frame += len(batch[0])
+ sum_images += 1
+ # Get the final metric, e.g. acc or hmean
+ metric = self.eval_class.get_metric()
+
+ pbar.close()
+ # self.model.train()
+ metric['fps'] = total_frame / total_time
+ return metric
+
+ def test_dataloader(self):
+ starttime = time.time()
+ count = 0
+ try:
+ for data in self.train_dataloader:
+ count += 1
+ if count % 1 == 0:
+ batch_time = time.time() - starttime
+ starttime = time.time()
+ self.logger.info(
+ f'reader: {count}, {data[0].shape}, {batch_time}')
+ except:
+ import traceback
+
+ self.logger.info(traceback.format_exc())
+ self.logger.info(f'finish reader: {count}, Success!')
diff --git a/tools/eval_rec.py b/tools/eval_rec.py
new file mode 100644
index 0000000000000000000000000000000000000000..649c20cc11aa2969a231e60118c8b1866745fbec
--- /dev/null
+++ b/tools/eval_rec.py
@@ -0,0 +1,41 @@
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+from tools.engine import Config, Trainer
+from tools.utility import ArgsParser
+
+
+def parse_args():
+ parser = ArgsParser()
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ FLAGS = parse_args()
+ cfg = Config(FLAGS.config)
+ FLAGS = vars(FLAGS)
+ opt = FLAGS.pop('opt')
+ cfg.merge_dict(FLAGS)
+ cfg.merge_dict(opt)
+ trainer = Trainer(cfg, mode='eval')
+
+ best_model_dict = trainer.status.get('metrics', {})
+ trainer.logger.info('metric in ckpt ***************')
+ for k, v in best_model_dict.items():
+ trainer.logger.info('{}:{}'.format(k, v))
+
+ metric = trainer.eval()
+
+ trainer.logger.info('metric eval ***************')
+ for k, v in metric.items():
+ trainer.logger.info('{}:{}'.format(k, v))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/eval_rec_all_ch.py b/tools/eval_rec_all_ch.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b7671bc5964c15daf42e313e2f37e2d29a2804e
--- /dev/null
+++ b/tools/eval_rec_all_ch.py
@@ -0,0 +1,184 @@
+import csv
+import os
+import sys
+import numpy as np
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+from tools.data import build_dataloader
+from tools.engine import Config, Trainer
+from tools.utility import ArgsParser
+
+
+def parse_args():
+ parser = ArgsParser()
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ FLAGS = parse_args()
+ cfg = Config(FLAGS.config)
+ FLAGS = vars(FLAGS)
+ opt = FLAGS.pop('opt')
+ cfg.merge_dict(FLAGS)
+ cfg.merge_dict(opt)
+ msr = False
+ if 'RatioDataSet' in cfg.cfg['Eval']['dataset']['name']:
+ msr = True
+
+ if cfg.cfg['Global']['output_dir'][-1] == '/':
+ cfg.cfg['Global']['output_dir'] = cfg.cfg['Global']['output_dir'][:-1]
+ if cfg.cfg['Global']['pretrained_model'] is None:
+ cfg.cfg['Global'][
+ 'pretrained_model'] = cfg.cfg['Global']['output_dir'] + '/best.pth'
+ cfg.cfg['Global']['use_amp'] = False
+ cfg.cfg['PostProcess']['with_ratio'] = True
+ cfg.cfg['Metric']['with_ratio'] = True
+ cfg.cfg['Metric']['max_len'] = 25
+ cfg.cfg['Metric']['max_ratio'] = 12
+ cfg.cfg['Eval']['dataset']['transforms'][-1]['KeepKeys'][
+ 'keep_keys'].append('real_ratio')
+ trainer = Trainer(cfg, mode='eval')
+
+ best_model_dict = trainer.status.get('metrics', {})
+ trainer.logger.info('metric in ckpt ***************')
+ for k, v in best_model_dict.items():
+ trainer.logger.info('{}:{}'.format(k, v))
+
+ data_dirs_list = [[
+ '../benchmark_bctr/benchmark_bctr_test/scene_test',
+ '../benchmark_bctr/benchmark_bctr_test/web_test',
+ '../benchmark_bctr/benchmark_bctr_test/document_test',
+ '../benchmark_bctr/benchmark_bctr_test/handwriting_test'
+ ]]
+ cfg = cfg.cfg
+ file_csv = open(
+ cfg['Global']['output_dir'] + '/' +
+ cfg['Global']['output_dir'].split('/')[-1] +
+ '_eval_all_ch_length_ratio.csv', 'w')
+ csv_w = csv.writer(file_csv)
+
+ for data_dirs in data_dirs_list:
+
+ acc_each = []
+ acc_each_real = []
+ acc_each_ingore_space = []
+ acc_each_ignore_space_symbol = []
+ acc_each_lower_ignore_space_symbol = []
+ acc_each_num = []
+ acc_each_dis = []
+ each_len = {}
+ each_ratio = {}
+ for datadir in data_dirs:
+ config_each = cfg.copy()
+ if msr:
+ config_each['Eval']['dataset']['data_dir_list'] = [datadir]
+ else:
+ config_each['Eval']['dataset']['data_dir'] = datadir
+ # config_each['Eval']['dataset']['label_file_list']=[label_file_list]
+ valid_dataloader = build_dataloader(config_each, 'Eval',
+ trainer.logger)
+ trainer.logger.info(
+ f'{datadir} valid dataloader has {len(valid_dataloader)} iters'
+ )
+ # valid_dataloaders.append(valid_dataloader)
+ trainer.valid_dataloader = valid_dataloader
+ metric = trainer.eval()
+ acc_each.append(metric['acc'] * 100)
+ acc_each_real.append(metric['acc_real'] * 100)
+ acc_each_ingore_space.append(metric['acc_ignore_space'] * 100)
+ acc_each_ignore_space_symbol.append(
+ metric['acc_ignore_space_symbol'] * 100)
+ acc_each_lower_ignore_space_symbol.append(
+ metric['acc_lower_ignore_space_symbol'] * 100)
+ acc_each_dis.append(metric['norm_edit_dis'])
+ acc_each_num.append(metric['num_samples'])
+
+ trainer.logger.info('metric eval ***************')
+ csv_w.writerow([datadir])
+ for k, v in metric.items():
+ trainer.logger.info('{}:{}'.format(k, v))
+ if 'each' in k:
+ csv_w.writerow([k] + v)
+ if 'each_len' in k:
+ each_len[k] = each_len.get(k, []) + [np.array(v)]
+ if 'each_ratio' in k:
+ each_ratio[k] = each_ratio.get(k, []) + [np.array(v)]
+ data_name = [
+ data_n[:-1].split('/')[-1]
+ if data_n[-1] == '/' else data_n.split('/')[-1]
+ for data_n in data_dirs
+ ]
+ csv_w.writerow(['-'] + data_name + ['arithmetic_avg'] +
+ ['weighted_avg'])
+ csv_w.writerow([''] + acc_each_num)
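+ # Each metric row ends with two summary columns: the arithmetic mean over
+ # datasets and the sample-weighted mean, e.g. acc = [90, 70] with
+ # num_samples = [3000, 1000] gives 80.0 (arithmetic) and 85.0 (weighted).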
+ avg1 = np.array(acc_each) * np.array(acc_each_num) / sum(acc_each_num)
+ csv_w.writerow(['acc'] + acc_each + [sum(acc_each) / len(acc_each)] +
+ [avg1.sum().tolist()])
+ print(acc_each + [sum(acc_each) / len(acc_each)] +
+ [avg1.sum().tolist()])
+ avg1 = np.array(acc_each_dis) * np.array(acc_each_num) / sum(
+ acc_each_num)
+ csv_w.writerow(['norm_edit_dis'] + acc_each_dis +
+ [sum(acc_each_dis) / len(acc_each)] +
+ [avg1.sum().tolist()])
+
+ avg1 = np.array(acc_each_real) * np.array(acc_each_num) / sum(
+ acc_each_num)
+ csv_w.writerow(['acc_real'] + acc_each_real +
+ [sum(acc_each_real) / len(acc_each_real)] +
+ [avg1.sum().tolist()])
+ avg1 = np.array(acc_each_ingore_space) * np.array(acc_each_num) / sum(
+ acc_each_num)
+ csv_w.writerow(
+ ['acc_ignore_space'] + acc_each_ingore_space +
+ [sum(acc_each_ingore_space) / len(acc_each_ingore_space)] +
+ [avg1.sum().tolist()])
+ avg1 = np.array(acc_each_ignore_space_symbol) * np.array(
+ acc_each_num) / sum(acc_each_num)
+ csv_w.writerow(['acc_ignore_space_symbol'] +
+ acc_each_ignore_space_symbol + [
+ sum(acc_each_ignore_space_symbol) /
+ len(acc_each_ignore_space_symbol)
+ ] + [avg1.sum().tolist()])
+ avg1 = np.array(acc_each_lower_ignore_space_symbol) * np.array(
+ acc_each_num) / sum(acc_each_num)
+ csv_w.writerow(['acc_lower_ignore_space_symbol'] +
+ acc_each_lower_ignore_space_symbol + [
+ sum(acc_each_lower_ignore_space_symbol) /
+ len(acc_each_lower_ignore_space_symbol)
+ ] + [avg1.sum().tolist()])
+
+ sum_all = np.array(each_len['each_len_num']).sum(0)
+ for k, v in each_len.items():
+ if k != 'each_len_num':
+ v_sum_weight = (np.array(v) *
+ np.array(each_len['each_len_num'])).sum(0)
+ sum_all_pad = np.where(sum_all == 0, 1., sum_all)
+ v_all = v_sum_weight / sum_all_pad
+ v_all = np.where(sum_all == 0, 0., v_all)
+ csv_w.writerow([k] + v_all.tolist())
+ else:
+ csv_w.writerow([k] + sum_all.tolist())
+
+ sum_all = np.array(each_ratio['each_ratio_num']).sum(0)
+ for k, v in each_ratio.items():
+ if k != 'each_ratio_num':
+ v_sum_weight = (np.array(v) *
+ np.array(each_ratio['each_ratio_num'])).sum(0)
+ sum_all_pad = np.where(sum_all == 0, 1., sum_all)
+ v_all = v_sum_weight / sum_all_pad
+ v_all = np.where(sum_all == 0, 0., v_all)
+ csv_w.writerow([k] + v_all.tolist())
+ else:
+ csv_w.writerow([k] + sum_all.tolist())
+
+ file_csv.close()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/eval_rec_all_en.py b/tools/eval_rec_all_en.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8c11c1a7e0e9883f17528f4d0367d4dd11a865c
--- /dev/null
+++ b/tools/eval_rec_all_en.py
@@ -0,0 +1,206 @@
+import csv
+import os
+import sys
+import numpy as np
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+from tools.data import build_dataloader
+from tools.engine import Config, Trainer
+from tools.utility import ArgsParser
+
+
+def parse_args():
+ parser = ArgsParser()
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ FLAGS = parse_args()
+ cfg = Config(FLAGS.config)
+ FLAGS = vars(FLAGS)
+ opt = FLAGS.pop('opt')
+ cfg.merge_dict(FLAGS)
+ cfg.merge_dict(opt)
+
+ msr = False
+ if 'RatioDataSet' in cfg.cfg['Eval']['dataset']['name']:
+ msr = True
+
+ if cfg.cfg['Global']['output_dir'][-1] == '/':
+ cfg.cfg['Global']['output_dir'] = cfg.cfg['Global']['output_dir'][:-1]
+ if cfg.cfg['Global']['pretrained_model'] is None:
+ cfg.cfg['Global'][
+ 'pretrained_model'] = cfg.cfg['Global']['output_dir'] + '/best.pth'
+ cfg.cfg['Global']['use_amp'] = False
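+ # Enable aspect-ratio-aware post-processing and length/ratio-bucketed metrics (up to length 25 and ratio 12) so per-length and per-ratio accuracy can be aggregated below.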
+ cfg.cfg['PostProcess']['with_ratio'] = True
+ cfg.cfg['Metric']['with_ratio'] = True
+ cfg.cfg['Metric']['max_len'] = 25
+ cfg.cfg['Metric']['max_ratio'] = 12
+ cfg.cfg['Eval']['dataset']['transforms'][-1]['KeepKeys'][
+ 'keep_keys'].append('real_ratio')
+ trainer = Trainer(cfg, mode='eval')
+
+ best_model_dict = trainer.status.get('metrics', {})
+ trainer.logger.info('metric in ckpt ***************')
+ for k, v in best_model_dict.items():
+ trainer.logger.info('{}:{}'.format(k, v))
+
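+ # Benchmark groups evaluated in turn: the six common English benchmarks (IIIT5k, SVT, IC13, IC15, SVTP, CUTE80), the Union14M-Benchmark subsets, OST (weak/heavy occlusion), and WordArt plus the larger IC13/IC15 splits. Paths are relative to the working directory.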
+ data_dirs_list = [
+ [
+ '../test/IIIT5k/', '../test/SVT/', '../test/IC13_857/',
+ '../test/IC15_1811/', '../test/SVTP/', '../test/CUTE80/'
+ ],
+ [
+ '../u14m/curve/', '../u14m/multi_oriented/', '../u14m/artistic/',
+ '../u14m/contextless/', '../u14m/salient/', '../u14m/multi_words/',
+ '../u14m/general/'
+ ], ['../OST/weak/', '../OST/heavy/'],
+ ['../wordart_test/', '../test/IC13_1015/', '../test/IC15_2077/']
+ ]
+ cfg = cfg.cfg
+ file_csv = open(
+ cfg['Global']['output_dir'] + '/' +
+ cfg['Global']['output_dir'].split('/')[-1] +
+ '_eval_all_length_ratio.csv', 'w')
+ csv_w = csv.writer(file_csv)
+ cfg['Eval']['dataset']['name'] = cfg['Eval']['dataset']['name'] + 'Test'
+ for data_dirs in data_dirs_list:
+
+ acc_each = []
+ acc_each_real = []
+ acc_each_lower = []
+ acc_each_ingore_space = []
+ acc_each_ingore_space_lower = []
+ acc_each_ignore_space_symbol = []
+ acc_each_lower_ignore_space_symbol = []
+ acc_each_num = []
+ acc_each_dis = []
+ each_len = {}
+ each_ratio = {}
+ for datadir in data_dirs:
+ config_each = cfg.copy()
+ if msr:
+ config_each['Eval']['dataset']['data_dir_list'] = [datadir]
+ else:
+ config_each['Eval']['dataset']['data_dir'] = datadir
+ valid_dataloader = build_dataloader(config_each, 'Eval',
+ trainer.logger)
+ trainer.logger.info(
+ f'{datadir} valid dataloader has {len(valid_dataloader)} iters'
+ )
+ trainer.valid_dataloader = valid_dataloader
+ metric = trainer.eval()
+ acc_each.append(metric['acc'] * 100)
+ acc_each_real.append(metric['acc_real'] * 100)
+ acc_each_lower.append(metric['acc_lower'] * 100)
+ acc_each_ingore_space.append(metric['acc_ignore_space'] * 100)
+ acc_each_ingore_space_lower.append(
+ metric['acc_ignore_space_lower'] * 100)
+ acc_each_ignore_space_symbol.append(
+ metric['acc_ignore_space_symbol'] * 100)
+ acc_each_lower_ignore_space_symbol.append(
+ metric['acc_ignore_space_lower_symbol'] * 100)
+ acc_each_dis.append(metric['norm_edit_dis'])
+ acc_each_num.append(metric['num_samples'])
+
+ trainer.logger.info('metric eval ***************')
+ csv_w.writerow([datadir])
+ for k, v in metric.items():
+ trainer.logger.info('{}:{}'.format(k, v))
+ if 'each' in k:
+ csv_w.writerow([k] + v)
+ if 'each_len' in k:
+ each_len[k] = each_len.get(k, []) + [np.array(v)]
+ if 'each_ratio' in k:
+ each_ratio[k] = each_ratio.get(k, []) + [np.array(v)]
+ data_name = [
+ data_n[:-1].split('/')[-1]
+ if data_n[-1] == '/' else data_n.split('/')[-1]
+ for data_n in data_dirs
+ ]
+ csv_w.writerow(['-'] + data_name + ['arithmetic_avg'] +
+ ['weighted_avg'])
+ csv_w.writerow([''] + acc_each_num)
+ avg1 = np.array(acc_each) * np.array(acc_each_num) / sum(acc_each_num)
+ csv_w.writerow(['acc'] + acc_each + [sum(acc_each) / len(acc_each)] +
+ [avg1.sum().tolist()])
+ print(acc_each + [sum(acc_each) / len(acc_each)] +
+ [avg1.sum().tolist()])
+ avg1 = np.array(acc_each_dis) * np.array(acc_each_num) / sum(
+ acc_each_num)
+ csv_w.writerow(['norm_edit_dis'] + acc_each_dis +
+ [sum(acc_each_dis) / len(acc_each_dis)] +
+ [avg1.sum().tolist()])
+
+ avg1 = np.array(acc_each_real) * np.array(acc_each_num) / sum(
+ acc_each_num)
+ csv_w.writerow(['acc_real'] + acc_each_real +
+ [sum(acc_each_real) / len(acc_each_real)] +
+ [avg1.sum().tolist()])
+ avg1 = np.array(acc_each_lower) * np.array(acc_each_num) / sum(
+ acc_each_num)
+ csv_w.writerow(['acc_lower'] + acc_each_lower +
+ [sum(acc_each_lower) / len(acc_each_lower)] +
+ [avg1.sum().tolist()])
+ avg1 = np.array(acc_each_ingore_space) * np.array(acc_each_num) / sum(
+ acc_each_num)
+ csv_w.writerow(
+ ['acc_ignore_space'] + acc_each_ingore_space +
+ [sum(acc_each_ingore_space) / len(acc_each_ingore_space)] +
+ [avg1.sum().tolist()])
+ avg1 = np.array(acc_each_ingore_space_lower) * np.array(
+ acc_each_num) / sum(acc_each_num)
+ csv_w.writerow(['acc_ignore_space_lower'] +
+ acc_each_ingore_space_lower + [
+ sum(acc_each_ingore_space_lower) /
+ len(acc_each_ingore_space_lower)
+ ] + [avg1.sum().tolist()])
+ avg1 = np.array(acc_each_ignore_space_symbol) * np.array(
+ acc_each_num) / sum(acc_each_num)
+ csv_w.writerow(['acc_ignore_space_symbol'] +
+ acc_each_ignore_space_symbol + [
+ sum(acc_each_ignore_space_symbol) /
+ len(acc_each_ignore_space_symbol)
+ ] + [avg1.sum().tolist()])
+ avg1 = np.array(acc_each_lower_ignore_space_symbol) * np.array(
+ acc_each_num) / sum(acc_each_num)
+ csv_w.writerow(['acc_ignore_space_lower_symbol'] +
+ acc_each_lower_ignore_space_symbol + [
+ sum(acc_each_lower_ignore_space_symbol) /
+ len(acc_each_lower_ignore_space_symbol)
+ ] + [avg1.sum().tolist()])
+
+ sum_all = np.array(each_len['each_len_num']).sum(0)
+ for k, v in each_len.items():
+ if k != 'each_len_num':
+ v_sum_weight = (np.array(v) *
+ np.array(each_len['each_len_num'])).sum(0)
+ sum_all_pad = np.where(sum_all == 0, 1., sum_all)
+ v_all = v_sum_weight / sum_all_pad
+ v_all = np.where(sum_all == 0, 0., v_all)
+ csv_w.writerow([k] + v_all.tolist())
+ else:
+ csv_w.writerow([k] + sum_all.tolist())
+
+ sum_all = np.array(each_ratio['each_ratio_num']).sum(0)
+ for k, v in each_ratio.items():
+ if k != 'each_ratio_num':
+ v_sum_weight = (np.array(v) *
+ np.array(each_ratio['each_ratio_num'])).sum(0)
+ sum_all_pad = np.where(sum_all == 0, 1., sum_all)
+ v_all = v_sum_weight / sum_all_pad
+ v_all = np.where(sum_all == 0, 0., v_all)
+ csv_w.writerow([k] + v_all.tolist())
+ else:
+ csv_w.writerow([k] + sum_all.tolist())
+
+ file_csv.close()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/eval_rec_all_long.py b/tools/eval_rec_all_long.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae60e5babfb9296f94ed6b0b0efc773b621cd2bb
--- /dev/null
+++ b/tools/eval_rec_all_long.py
@@ -0,0 +1,119 @@
+import csv
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+import numpy as np
+
+from tools.data import build_dataloader
+from tools.engine import Config, Trainer
+from tools.utility import ArgsParser
+
+
+def parse_args():
+ parser = ArgsParser()
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ FLAGS = parse_args()
+ cfg = Config(FLAGS.config)
+ FLAGS = vars(FLAGS)
+ opt = FLAGS.pop('opt')
+ cfg.merge_dict(FLAGS)
+ cfg.merge_dict(opt)
+
+ cfg.cfg['Global']['use_amp'] = False
+ if cfg.cfg['Global']['output_dir'][-1] == '/':
+ cfg.cfg['Global']['output_dir'] = cfg.cfg['Global']['output_dir'][:-1]
+ cfg.cfg['Global']['max_text_length'] = 200
+ cfg.cfg['Architecture']['Decoder']['max_len'] = 200
+ cfg.cfg['Metric']['name'] = 'RecMetricLong'
+ if cfg.cfg['Global']['pretrained_model'] is None:
+ cfg.cfg['Global'][
+ 'pretrained_model'] = cfg.cfg['Global']['output_dir'] + '/best.pth'
+ trainer = Trainer(cfg, mode='eval')
+
+ best_model_dict = trainer.status.get('metrics', {})
+ trainer.logger.info('metric in ckpt ***************')
+ for k, v in best_model_dict.items():
+ trainer.logger.info('{}:{}'.format(k, v))
+
+ data_dirs_list = [
+ ['../ltb/long_lmdb'],
+ ]
+
+ cfg = cfg.cfg
+ file_csv = open(
+ cfg['Global']['output_dir'] + '/' +
+ cfg['Global']['output_dir'].split('/')[-1] +
+ '_result1_1_test_all_long_final_ultra_bs1.csv', 'w')
+ csv_w = csv.writer(file_csv)
+
+ for data_dirs in data_dirs_list:
+ acc_each = []
+ acc_each_num = []
+ acc_each_dis = []
+ each_long = {}
+ for datadir in data_dirs:
+ config_each = cfg.copy()
+
+ config_each['Eval']['dataset']['data_dir_list'] = [datadir]
+ valid_dataloader = build_dataloader(config_each, 'Eval',
+ trainer.logger)
+ trainer.logger.info(
+ f'{datadir} valid dataloader has {len(valid_dataloader)} iters'
+ )
+ trainer.valid_dataloader = valid_dataloader
+ metric = trainer.eval()
+ acc_each.append(metric['acc'] * 100)
+ acc_each_dis.append(metric['norm_edit_dis'])
+ acc_each_num.append(metric['all_num'])
+
+ trainer.logger.info('metric eval ***************')
+ for k, v in metric.items():
+ trainer.logger.info('{}:{}'.format(k, v))
+ if 'each' in k:
+ csv_w.writerow([k] + v[26:])
+ each_long[k] = each_long.get(k, []) + [np.array(v[26:])]
+ avg1 = np.array(acc_each) * np.array(acc_each_num) / sum(acc_each_num)
+ csv_w.writerow(acc_each + [avg1.sum().tolist()] +
+ [sum(acc_each) / len(acc_each)])
+ print(acc_each + [avg1.sum().tolist()] +
+ [sum(acc_each) / len(acc_each)])
+ avg1 = np.array(acc_each_dis) * np.array(acc_each_num) / sum(
+ acc_each_num)
+ csv_w.writerow(acc_each_dis + [avg1.sum().tolist()] +
+ [sum(acc_each_dis) / len(acc_each_dis)])
+
+ sum_all = np.array(each_long['each_len_num']).sum(0)
+ for k, v in each_long.items():
+ if k != 'each_len_num':
+ v_sum_weight = (np.array(v) *
+ np.array(each_long['each_len_num'])).sum(0)
+ sum_all_pad = np.where(sum_all == 0, 1., sum_all)
+ v_all = v_sum_weight / sum_all_pad
+ v_all = np.where(sum_all == 0, 0., v_all)
+ csv_w.writerow([k] + v_all.tolist())
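+ # Summarize long-text buckets: after slicing v[26:], indices 0-9 correspond to lengths 26-35, 10-29 to 36-55, and 30+ to 56 and longer.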
+ v_26_40 = (v_all[:10] * sum_all[:10]) / sum_all[:10].sum()
+ csv_w.writerow([k + '26_35'] + [v_26_40.sum().tolist()] +
+ [sum_all[:10].sum().tolist()])
+ v_41_55 = (v_all[10:30] *
+ sum_all[10:30]) / sum_all[10:30].sum()
+ csv_w.writerow([k + '36_55'] + [v_41_55.sum().tolist()] +
+ [sum_all[10:30].sum().tolist()])
+ v_56_70 = (v_all[30:] * sum_all[30:]) / sum_all[30:].sum()
+ csv_w.writerow([k + '56'] + [v_56_70.sum().tolist()] +
+ [sum_all[30:].sum().tolist()])
+ else:
+ csv_w.writerow([k] + sum_all.tolist())
+ file_csv.close()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/eval_rec_all_long_simple.py b/tools/eval_rec_all_long_simple.py
new file mode 100644
index 0000000000000000000000000000000000000000..a86d4dfd48ba5375e2bfd7617c3f1465e47e48fd
--- /dev/null
+++ b/tools/eval_rec_all_long_simple.py
@@ -0,0 +1,122 @@
+import csv
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+import numpy as np
+
+from tools.data import build_dataloader
+from tools.engine import Config, Trainer
+from tools.utility import ArgsParser
+
+
+def parse_args():
+ parser = ArgsParser()
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ FLAGS = parse_args()
+ cfg = Config(FLAGS.config)
+ FLAGS = vars(FLAGS)
+ opt = FLAGS.pop('opt')
+ cfg.merge_dict(FLAGS)
+ cfg.merge_dict(opt)
+
+ cfg.cfg['Global']['use_amp'] = False
+ if cfg.cfg['Global']['output_dir'][-1] == '/':
+ cfg.cfg['Global']['output_dir'] = cfg.cfg['Global']['output_dir'][:-1]
+ cfg.cfg['Global']['max_text_length'] = 200
+ cfg.cfg['Architecture']['Decoder']['max_len'] = 200
+ cfg.cfg['Metric']['name'] = 'RecMetricLong'
+ if cfg.cfg['Global']['pretrained_model'] is None:
+ cfg.cfg['Global'][
+ 'pretrained_model'] = cfg.cfg['Global']['output_dir'] + '/best.pth'
+ trainer = Trainer(cfg, mode='eval')
+
+ best_model_dict = trainer.status.get('metrics', {})
+ trainer.logger.info('metric in ckpt ***************')
+ for k, v in best_model_dict.items():
+ trainer.logger.info('{}:{}'.format(k, v))
+
+ data_dirs_list = [
+ [
+ '../ltb/ultra_long_26_35_list.txt',
+ '../ltb/ultra_long_36_55_list.txt',
+ '../ltb/ultra_long_56_list.txt',
+ ],
+ ]
+
+ cfg = cfg.cfg
+ cfg['Eval']['dataset']['name'] = 'SimpleDataSet'
+ file_csv = open(
+ cfg['Global']['output_dir'] + '/' +
+ cfg['Global']['output_dir'].split('/')[-1] +
+ '_result1_1_test_all_long_simple_bi_bs1.csv', 'w')
+ csv_w = csv.writer(file_csv)
+
+ for data_dirs in data_dirs_list:
+ acc_each = []
+ acc_each_num = []
+ acc_each_dis = []
+ each_long = {}
+ for datadir in data_dirs:
+ config_each = cfg.copy()
+ config_each['Eval']['dataset']['label_file_list'] = [datadir]
+ valid_dataloader = build_dataloader(config_each, 'Eval',
+ trainer.logger)
+ trainer.logger.info(
+ f'{datadir} valid dataloader has {len(valid_dataloader)} iters'
+ )
+ trainer.valid_dataloader = valid_dataloader
+ metric = trainer.eval()
+ acc_each.append(metric['acc'] * 100)
+ acc_each_dis.append(metric['norm_edit_dis'])
+ acc_each_num.append(metric['all_num'])
+ trainer.logger.info('metric eval ***************')
+ for k, v in metric.items():
+ trainer.logger.info('{}:{}'.format(k, v))
+ if 'each' in k:
+ csv_w.writerow([k] + v[26:])
+ each_long[k] = each_long.get(k, []) + [np.array(v[26:])]
+ avg1 = np.array(acc_each) * np.array(acc_each_num) / sum(acc_each_num)
+ csv_w.writerow(acc_each + [avg1.sum().tolist()] +
+ [sum(acc_each) / len(acc_each)])
+ print(acc_each + [avg1.sum().tolist()] +
+ [sum(acc_each) / len(acc_each)])
+ avg1 = np.array(acc_each_dis) * np.array(acc_each_num) / sum(
+ acc_each_num)
+ csv_w.writerow(acc_each_dis + [avg1.sum().tolist()] +
+ [sum(acc_each_dis) / len(acc_each_dis)])
+
+ sum_all = np.array(each_long['each_len_num']).sum(0)
+ for k, v in each_long.items():
+ if k != 'each_len_num':
+ v_sum_weight = (np.array(v) *
+ np.array(each_long['each_len_num'])).sum(0)
+ sum_all_pad = np.where(sum_all == 0, 1., sum_all)
+ v_all = v_sum_weight / sum_all_pad
+ v_all = np.where(sum_all == 0, 0., v_all)
+ csv_w.writerow([k] + v_all.tolist())
+ v_26_40 = (v_all[:10] * sum_all[:10]) / sum_all[:10].sum()
+ csv_w.writerow([k + '26_35'] + [v_26_40.sum().tolist()] +
+ [sum_all[:10].sum().tolist()])
+ v_41_55 = (v_all[10:30] *
+ sum_all[10:30]) / sum_all[10:30].sum()
+ csv_w.writerow([k + '36_55'] + [v_41_55.sum().tolist()] +
+ [sum_all[10:30].sum().tolist()])
+ v_56_70 = (v_all[30:] * sum_all[30:]) / sum_all[30:].sum()
+ csv_w.writerow([k + '56'] + [v_56_70.sum().tolist()] +
+ [sum_all[30:].sum().tolist()])
+ else:
+ csv_w.writerow([k] + sum_all.tolist())
+ file_csv.close()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/export_rec.py b/tools/export_rec.py
new file mode 100644
index 0000000000000000000000000000000000000000..882ed3aa066045dcbff78b49b3db3fae8e57658c
--- /dev/null
+++ b/tools/export_rec.py
@@ -0,0 +1,118 @@
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+import torch
+
+from openrec.modeling import build_model
+from openrec.postprocess import build_post_process
+from tools.engine import Config
+from tools.infer_rec import build_rec_process
+from tools.utility import ArgsParser
+from tools.utils.ckpt import load_ckpt
+from tools.utils.logging import get_logger
+
+
+def to_onnx(model, dummy_input, dynamic_axes, save_path='model.onnx'):
+ input_axis_name = ['batch_size', 'channel', 'in_width', 'in_height']
+ output_axis_name = ['batch_size', 'channel', 'out_width', 'out_height']
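+ # dynamic_axes is an iterable of axis indices (e.g. [0, 3]) to export as dynamic; each index is mapped to a readable name for the ONNX input and output.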
+ torch.onnx.export(
+ model.to('cpu'),
+ dummy_input,
+ save_path,
+ input_names=['input'],
+ output_names=['output'], # the model's output names
+ dynamic_axes={
+ 'input': {axis: input_axis_name[axis]
+ for axis in dynamic_axes},
+ 'output': {axis: output_axis_name[axis]
+ for axis in dynamic_axes},
+ },
+ )
+
+
+def export_single_model(model: torch.nn.Module, _cfg, export_dir,
+ export_config, logger, type):
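+ # Fuse re-parameterizable branches into their inference-time form before export.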
+ for layer in model.modules():
+ if hasattr(layer, 'rep') and not getattr(layer, 'is_repped'):
+ layer.rep()
+ os.makedirs(export_dir, exist_ok=True)
+
+ export_cfg = {'PostProcess': _cfg['PostProcess']}
+ export_cfg['Transforms'] = build_rec_process(_cfg)
+
+ cfg.save(os.path.join(export_dir, 'config.yaml'), export_cfg)
+
+ dummy_input = torch.randn(*export_config['export_shape'], device='cpu')
+ if type == 'script':
+ save_path = os.path.join(export_dir, 'model.pt')
+ trace_model = torch.jit.trace(model, dummy_input, strict=False)
+ torch.jit.save(trace_model, save_path)
+ elif type == 'onnx':
+ save_path = os.path.join(export_dir, 'model.onnx')
+ to_onnx(model, dummy_input, export_config.get('dynamic_axes', []),
+ save_path)
+ else:
+ raise NotImplementedError
+ logger.info(f'finished exporting model to {save_path}')
+
+
+def main(cfg, type):
+ _cfg = cfg.cfg
+ logger = get_logger()
+ global_config = _cfg['Global']
+ export_config = _cfg['Export']
+ # build post process
+ post_process_class = build_post_process(_cfg['PostProcess'])
+ char_num = len(getattr(post_process_class, 'character'))
+ _cfg['Architecture']['Decoder']['out_channels'] = char_num
+ model = build_model(_cfg['Architecture'])
+
+ load_ckpt(model, _cfg)
+ model.eval()
+
+ export_dir = export_config.get('export_dir', '')
+ if not export_dir:
+ export_dir = os.path.join(global_config.get('output_dir', 'output'),
+ 'export')
+
+ if _cfg['Architecture']['algorithm'] in ['Distillation'
+ ]: # distillation model
+ _cfg['PostProcess'][
+ 'name'] = post_process_class.__class__.__base__.__name__
+ for model_name in model.model_list:
+ sub_model_save_path = os.path.join(export_dir, model_name)
+ export_single_model(
+ model.model_list[model_name],
+ _cfg,
+ sub_model_save_path,
+ export_config,
+ logger,
+ type,
+ )
+ else:
+ export_single_model(model, _cfg, export_dir, export_config, logger,
+ type)
+
+
+def parse_args():
+ parser = ArgsParser()
+ parser.add_argument('--type',
+ type=str,
+ default='onnx',
+ help='type of export')
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ FLAGS = parse_args()
+ cfg = Config(FLAGS.config)
+ FLAGS = vars(FLAGS)
+ opt = FLAGS.pop('opt')
+ cfg.merge_dict(FLAGS)
+ cfg.merge_dict(opt)
+ main(cfg, FLAGS['type'])
diff --git a/tools/infer/onnx_engine.py b/tools/infer/onnx_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1a774febd9dba03e4fb6454f1e9d37be4412a8b
--- /dev/null
+++ b/tools/infer/onnx_engine.py
@@ -0,0 +1,65 @@
+import os
+
+import onnxruntime
+
+
+class ONNXEngine:
+
+ def __init__(self, onnx_path, use_gpu):
+ """
+ :param onnx_path:
+ """
+ if not os.path.exists(onnx_path):
+ raise Exception(f'{onnx_path} does not exist')
+
+ providers = ['CPUExecutionProvider']
+ if use_gpu:
+ providers = [
+ 'TensorrtExecutionProvider',
+ 'CUDAExecutionProvider',
+ 'CPUExecutionProvider',
+ ]
+ self.onnx_session = onnxruntime.InferenceSession(onnx_path,
+ providers=providers)
+ self.input_name = self.get_input_name(self.onnx_session)
+ self.output_name = self.get_output_name(self.onnx_session)
+
+ def get_output_name(self, onnx_session):
+ """
+ output_name = onnx_session.get_outputs()[0].name
+ :param onnx_session:
+ :return:
+ """
+ output_name = []
+ for node in onnx_session.get_outputs():
+ output_name.append(node.name)
+ return output_name
+
+ def get_input_name(self, onnx_session):
+ """
+ input_name = onnx_session.get_inputs()[0].name
+ :param onnx_session:
+ :return:
+ """
+ input_name = []
+ for node in onnx_session.get_inputs():
+ input_name.append(node.name)
+ return input_name
+
+ def get_input_feed(self, input_name, image_numpy):
+ """
+ input_feed={self.input_name: image_numpy}
+ :param input_name:
+ :param image_numpy:
+ :return:
+ """
+ input_feed = {}
+ for name in input_name:
+ input_feed[name] = image_numpy
+ return input_feed
+
+ def run(self, image_numpy):
+ # The input data type must match what the model expects; build the feed dict and run inference.
+ input_feed = self.get_input_feed(self.input_name, image_numpy)
+ result = self.onnx_session.run(self.output_name, input_feed=input_feed)
+ return result
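+
+
+# Minimal usage sketch (paths are illustrative):
+# engine = ONNXEngine('output/export/model.onnx', use_gpu=False)
+# outputs = engine.run(image_numpy)  # image_numpy: float32 batch matching the exported input shape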
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
new file mode 100644
index 0000000000000000000000000000000000000000..09060149c25e5ce9136a2bacb46bef1566e3415a
--- /dev/null
+++ b/tools/infer/predict_rec.py
@@ -0,0 +1,140 @@
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+
+import math
+import time
+
+import cv2
+import numpy as np
+
+from openrec.postprocess import build_post_process
+from openrec.preprocess import create_operators, transform
+from tools.engine import Config
+from tools.infer.onnx_engine import ONNXEngine
+from tools.infer.utility import check_gpu, parse_args
+from tools.utils.logging import get_logger
+from tools.utils.utility import check_and_read, get_image_file_list
+
+logger = get_logger()
+
+
+class TextRecognizer(ONNXEngine):
+
+ def __init__(self, args):
+ if args.rec_model_dir is None or not os.path.exists(
+ args.rec_model_dir):
+ raise Exception(
+ f'args.rec_model_dir is set to {args.rec_model_dir}, but it does not exist'
+ )
+
+ onnx_path = os.path.join(args.rec_model_dir, 'model.onnx')
+ config_path = os.path.join(args.rec_model_dir, 'config.yaml')
+ super(TextRecognizer, self).__init__(onnx_path, args.use_gpu)
+
+ self.rec_image_shape = [
+ int(v) for v in args.rec_image_shape.split(',')
+ ]
+ self.rec_batch_num = args.rec_batch_num
+ self.rec_algorithm = args.rec_algorithm
+
+ cfg = Config(config_path).cfg
+ self.ops = create_operators(cfg['Transforms'][1:])
+ self.postprocess_op = build_post_process(cfg['PostProcess'])
+
+ def resize_norm_img(self, img, max_wh_ratio):
+ imgC, imgH, imgW = self.rec_image_shape
+ assert imgC == img.shape[2]
+ imgW = int((imgH * max_wh_ratio))
+ h, w = img.shape[:2]
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ resized_image = resized_image.astype('float32')
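+ # Normalize to [-1, 1] ((x / 255 - 0.5) / 0.5) and right-pad the width so every crop in the batch has shape (imgC, imgH, imgW).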
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+ padding_im[:, :, 0:resized_w] = resized_image
+ return padding_im
+
+ def __call__(self, img_list):
+ img_num = len(img_list)
+ # Calculate the aspect ratio of all text bars
+ width_list = []
+ for img in img_list:
+ width_list.append(img.shape[1] / float(img.shape[0]))
+ # Sorting can speed up the recognition process
+ indices = np.argsort(np.array(width_list))
+ rec_res = [['', 0.0]] * img_num
+ batch_num = self.rec_batch_num
+ st = time.time()
+ for beg_img_no in range(0, img_num, batch_num):
+ end_img_no = min(img_num, beg_img_no + batch_num)
+ norm_img_batch = []
+ imgC, imgH, imgW = self.rec_image_shape[:3]
+ max_wh_ratio = imgW / imgH
+ # max_wh_ratio = 0
+ for ino in range(beg_img_no, end_img_no):
+ h, w = img_list[indices[ino]].shape[0:2]
+ wh_ratio = w * 1.0 / h
+ max_wh_ratio = max(max_wh_ratio, wh_ratio)
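+ # Resize the whole batch to the widest aspect ratio present so no crop is squeezed narrower than its natural width.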
+ for ino in range(beg_img_no, end_img_no):
+ if self.rec_algorithm == 'nrtr':
+ norm_img = transform({'image': img_list[indices[ino]]},
+ self.ops)[0]
+ else:
+ norm_img = self.resize_norm_img(img_list[indices[ino]],
+ max_wh_ratio)
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
+ norm_img_batch = np.concatenate(norm_img_batch)
+ norm_img_batch = norm_img_batch.copy()
+
+ preds = self.run(norm_img_batch)
+
+ if len(preds) == 1:
+ preds = preds[0]
+
+ rec_result = self.postprocess_op({'res': preds})
+ for rno in range(len(rec_result)):
+ rec_res[indices[beg_img_no + rno]] = rec_result[rno]
+ return rec_res, time.time() - st
+
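+# Usage sketch (assumes --rec_model_dir contains the exported model.onnx and config.yaml):
+# recognizer = TextRecognizer(parse_args())
+# rec_res, elapse = recognizer([cv2.imread('word.jpg')])  # 'word.jpg' is an illustrative path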
+
+def main(args):
+ args.use_gpu = check_gpu(args.use_gpu)
+
+ image_file_list = get_image_file_list(args.image_dir)
+ text_recognizer = TextRecognizer(args)
+ valid_image_file_list = []
+ img_list = []
+
+ # warmup 2 times
+ if args.warmup:
+ img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8)
+ for i in range(2):
+ text_recognizer([img] * int(args.rec_batch_num))
+
+ for image_file in image_file_list:
+ img, flag, _ = check_and_read(image_file)
+ if not flag:
+ img = cv2.imread(image_file)
+ if img is None:
+ logger.info(f'error in loading image:{image_file}')
+ continue
+ valid_image_file_list.append(image_file)
+ img_list.append(img)
+ rec_res, _ = text_recognizer(img_list)
+ for ino in range(len(img_list)):
+ logger.info(f'result of {valid_image_file_list[ino]}:{rec_res[ino]}')
+
+
+if __name__ == '__main__':
+ main(parse_args())
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
new file mode 100644
index 0000000000000000000000000000000000000000..944f8cb390d8248e7dae711b4a478c33db89edd8
--- /dev/null
+++ b/tools/infer/utility.py
@@ -0,0 +1,234 @@
+import argparse
+import math
+
+import cv2
+import numpy as np
+import torch
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+import random
+
+def str2bool(v):
+ return v.lower() in ('true', 'yes', 't', 'y', '1')
+
+
+def str2int_tuple(v):
+ return tuple([int(i.strip()) for i in v.split(',')])
+
+
+def init_args():
+ parser = argparse.ArgumentParser()
+ # params for prediction engine
+ parser.add_argument('--use_gpu', type=str2bool, default=False)
+
+ # params for text detector
+ parser.add_argument('--image_dir', type=str)
+ parser.add_argument('--det_algorithm', type=str, default='DB')
+ parser.add_argument('--det_model_dir', type=str)
+ parser.add_argument('--det_limit_side_len', type=float, default=960)
+ parser.add_argument('--det_limit_type', type=str, default='max')
+ parser.add_argument('--det_box_type', type=str, default='quad')
+
+ # DB parmas
+ parser.add_argument('--det_db_thresh', type=float, default=0.3)
+ parser.add_argument('--det_db_box_thresh', type=float, default=0.6)
+ parser.add_argument('--det_db_unclip_ratio', type=float, default=1.5)
+ parser.add_argument('--max_batch_size', type=int, default=10)
+ parser.add_argument('--use_dilation', type=str2bool, default=False)
+ parser.add_argument('--det_db_score_mode', type=str, default='fast')
+
+ # params for text recognizer
+ parser.add_argument('--rec_algorithm', type=str, default='SVTR_LCNet')
+ parser.add_argument('--rec_model_dir', type=str)
+ parser.add_argument('--rec_image_inverse', type=str2bool, default=True)
+ parser.add_argument('--rec_image_shape', type=str, default='3, 48, 320')
+ parser.add_argument('--rec_batch_num', type=int, default=6)
+ parser.add_argument('--max_text_length', type=int, default=25)
+ parser.add_argument('--vis_font_path',
+ type=str,
+ default='./doc/fonts/simfang.ttf')
+ parser.add_argument('--drop_score', type=float, default=0.5)
+
+ # params for text classifier
+ parser.add_argument('--use_angle_cls', type=str2bool, default=False)
+ parser.add_argument('--cls_model_dir', type=str)
+ parser.add_argument('--cls_image_shape', type=str, default='3, 48, 192')
+ parser.add_argument('--label_list', type=list, default=['0', '180'])
+ parser.add_argument('--cls_batch_num', type=int, default=6)
+ parser.add_argument('--cls_thresh', type=float, default=0.9)
+
+ parser.add_argument('--warmup', type=str2bool, default=False)
+
+ #
+ parser.add_argument('--output', type=str, default='./inference_results')
+ parser.add_argument('--save_crop_res', type=str2bool, default=False)
+ parser.add_argument('--crop_res_save_dir', type=str, default='./output')
+
+ # multi-process
+ parser.add_argument('--use_mp', type=str2bool, default=False)
+ parser.add_argument('--total_process_num', type=int, default=1)
+ parser.add_argument('--process_id', type=int, default=0)
+
+ parser.add_argument('--show_log', type=str2bool, default=True)
+ return parser
+
+
+def parse_args():
+ parser = init_args()
+ return parser.parse_args()
+
+def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"):
+ font_size = int(sz[1] * 0.99)
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+ if int(PIL.__version__.split(".")[0]) < 10:
+ length = font.getsize(txt)[0]
+ else:
+ length = font.getlength(txt)
+
+ if length > sz[0]:
+ font_size = int(font_size * sz[0] / length)
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+ return font
+
+def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"):
+ box_height = int(
+ math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
+ )
+ box_width = int(
+ math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
+ )
+
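+ # Boxes that are much taller than wide are rendered horizontally and then rotated 270 degrees to fit vertically.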
+ if box_height > 2 * box_width and box_height > 30:
+ img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))
+ draw_text = ImageDraw.Draw(img_text)
+ if txt:
+ font = create_font(txt, (box_height, box_width), font_path)
+ draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+ img_text = img_text.transpose(Image.ROTATE_270)
+ else:
+ img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))
+ draw_text = ImageDraw.Draw(img_text)
+ if txt:
+ font = create_font(txt, (box_width, box_height), font_path)
+ draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+
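+ # Perspective-warp the axis-aligned rendered text onto the original quadrilateral.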
+ pts1 = np.float32(
+ [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
+ )
+ pts2 = np.array(box, dtype=np.float32)
+ M = cv2.getPerspectiveTransform(pts1, pts2)
+
+ img_text = np.array(img_text, dtype=np.uint8)
+ img_right_text = cv2.warpPerspective(
+ img_text,
+ M,
+ img_size,
+ flags=cv2.INTER_NEAREST,
+ borderMode=cv2.BORDER_CONSTANT,
+ borderValue=(255, 255, 255),
+ )
+ return img_right_text
+
+def draw_ocr_box_txt(
+ image,
+ boxes,
+ txts=None,
+ scores=None,
+ drop_score=0.5,
+ font_path="./doc/fonts/simfang.ttf",
+):
+ h, w = image.height, image.width
+ img_left = image.copy()
+ img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
+ random.seed(0)
+
+ draw_left = ImageDraw.Draw(img_left)
+ if txts is None or len(txts) != len(boxes):
+ txts = [None] * len(boxes)
+ for idx, (box, txt) in enumerate(zip(boxes, txts)):
+ if scores is not None and scores[idx] < drop_score:
+ continue
+ color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+ if isinstance(box[0], list):
+ box = list(map(tuple, box))
+ draw_left.polygon(box, fill=color)
+ img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
+ pts = np.array(box, np.int32).reshape((-1, 1, 2))
+ cv2.polylines(img_right_text, [pts], True, color, 1)
+ img_right = cv2.bitwise_and(img_right, img_right_text)
+ img_left = Image.blend(image, img_left, 0.5)
+ img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
+ img_show.paste(img_left, (0, 0, w, h))
+ img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
+ return np.array(img_show)
+
+
+def get_rotate_crop_image(img, points):
+ """
+ img_height, img_width = img.shape[0:2]
+ left = int(np.min(points[:, 0]))
+ right = int(np.max(points[:, 0]))
+ top = int(np.min(points[:, 1]))
+ bottom = int(np.max(points[:, 1]))
+ img_crop = img[top:bottom, left:right, :].copy()
+ points[:, 0] = points[:, 0] - left
+ points[:, 1] = points[:, 1] - top
+ """
+ assert len(points) == 4, 'shape of points must be 4*2'
+ img_crop_width = int(
+ max(np.linalg.norm(points[0] - points[1]),
+ np.linalg.norm(points[2] - points[3])))
+ img_crop_height = int(
+ max(np.linalg.norm(points[0] - points[3]),
+ np.linalg.norm(points[1] - points[2])))
+ pts_std = np.float32([
+ [0, 0],
+ [img_crop_width, 0],
+ [img_crop_width, img_crop_height],
+ [0, img_crop_height],
+ ])
+ M = cv2.getPerspectiveTransform(points, pts_std)
+ dst_img = cv2.warpPerspective(
+ img,
+ M,
+ (img_crop_width, img_crop_height),
+ borderMode=cv2.BORDER_REPLICATE,
+ flags=cv2.INTER_CUBIC,
+ )
+ dst_img_height, dst_img_width = dst_img.shape[0:2]
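+ # Rotate crops that are clearly vertical (height/width >= 1.5) so the recognizer sees roughly horizontal text.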
+ if dst_img_height * 1.0 / dst_img_width >= 1.5:
+ dst_img = np.rot90(dst_img)
+ return dst_img
+
+
+def get_minarea_rect_crop(img, points):
+ bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
+ points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+ index_a, index_b, index_c, index_d = 0, 1, 2, 3
+ if points[1][1] > points[0][1]:
+ index_a = 0
+ index_d = 1
+ else:
+ index_a = 1
+ index_d = 0
+ if points[3][1] > points[2][1]:
+ index_b = 2
+ index_c = 3
+ else:
+ index_b = 3
+ index_c = 2
+
+ box = [points[index_a], points[index_b], points[index_c], points[index_d]]
+ crop_img = get_rotate_crop_image(img, np.array(box))
+ return crop_img
+
+
+def check_gpu(use_gpu):
+ if use_gpu and not torch.cuda.is_available():
+ use_gpu = False
+ return use_gpu
+
+
+if __name__ == '__main__':
+ pass
diff --git a/tools/infer_det.py b/tools/infer_det.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf1d2a711820d69d219723988794160ac2e5ab88
--- /dev/null
+++ b/tools/infer_det.py
@@ -0,0 +1,459 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from pathlib import Path
+import time
+
+import numpy as np
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
+
+import cv2
+import json
+import torch
+from tools.engine import Config
+from tools.utility import ArgsParser
+from tools.utils.ckpt import load_ckpt
+from tools.utils.logging import get_logger
+from tools.utils.utility import get_image_file_list
+
+logger = get_logger()
+
+root_dir = Path(__file__).resolve().parent
+DEFAULT_CFG_PATH_DET = str(root_dir / '../configs/det/dbnet/repvit_db.yml')
+
+MODEL_NAME_DET = './openocr_det_repvit_ch.pth' # model file name
+DOWNLOAD_URL_DET = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth' # model file URL
+
+
+def check_and_download_model(model_name: str, url: str):
+ """
+ Check whether the pretrained model exists; if not, download it from the given URL into a fixed cache directory.
+
+ Args:
+ model_name (str): name of the model file, e.g. "model.pt"
+ url (str): download URL of the model file
+
+ Returns:
+ str: full path to the model file
+ """
+ if os.path.exists(model_name):
+ return model_name
+
+ # Cache path is fixed to ".cache/openocr" under the user's home directory
+ cache_dir = Path.home() / '.cache' / 'openocr'
+ model_path = cache_dir / model_name
+
+ # If the model file already exists, return its path directly
+ if model_path.exists():
+ logger.info(f'Model already exists at: {model_path}')
+ return str(model_path)
+
+ # Otherwise download the model
+ logger.info(f'Model not found. Downloading from {url}...')
+
+ # Create the cache directory if it does not exist
+ cache_dir.mkdir(parents=True, exist_ok=True)
+
+ try:
+ # Download the file
+ import urllib.request
+ with urllib.request.urlopen(url) as response, open(model_path,
+ 'wb') as out_file:
+ out_file.write(response.read())
+ logger.info(f'Model downloaded and saved at: {model_path}')
+ return str(model_path)
+
+ except Exception as e:
+ logger.error(f'Error downloading the model: {e}')
+ # Prompt the user to download the model manually
+ logger.error(
+ f'Unable to download the model automatically. '
+ f'Please download the model manually from the following URL:\n{url}\n'
+ f'and save it to: {model_name} or {model_path}')
+ raise RuntimeError(
+ f'Failed to download the model. Please download it manually from {url} '
+ f'and save it to {model_path}') from e
+
+
+def replace_batchnorm(net):
+ for child_name, child in net.named_children():
+ if hasattr(child, 'fuse'):
+ fused = child.fuse()
+ setattr(net, child_name, fused)
+ replace_batchnorm(fused)
+ elif isinstance(child, torch.nn.BatchNorm2d):
+ setattr(net, child_name, torch.nn.Identity())
+ else:
+ replace_batchnorm(child)
+
+
+def padding_image(img, size=(640, 640)):
+ """
+ Pad an image to the target size using OpenCV by adding black borders on the bottom and right.
+
+ :param img: Input image as a numpy array.
+ :param size: Target size (width, height), default (640, 640).
+ :return: The padded image.
+ """
+
+ img_height, img_width = img.shape[:2]
+ target_width, target_height = size
+
+ # If image is smaller than target size, pad the image to 640x640
+
+ # Calculate padding amounts (top, bottom, left, right)
+ pad_top = 0
+ pad_bottom = target_height - img_height
+ pad_left = 0
+ pad_right = target_width - img_width
+
+ # Pad the image (white padding, border type: constant)
+ padded_img = cv2.copyMakeBorder(img,
+ pad_top,
+ pad_bottom,
+ pad_left,
+ pad_right,
+ cv2.BORDER_CONSTANT,
+ value=[0, 0, 0])
+
+ # Return the padded image
+ return padded_img
+
+
+def resize_image(img, size=(640, 640), over_lap=64):
+ """
+ Resize or split an image using OpenCV:
+ - If the image is smaller than the target size, pad it to the target size.
+ - If it is larger, split it into overlapping target-size tiles and record their positions.
+
+ :param img: Input image as a numpy array.
+ :param size: Tile size (width, height), default (640, 640).
+ :param over_lap: Overlap in pixels between adjacent tiles.
+ :return: (list of tiles, list of [left, top, right, bottom] positions of each tile).
+ """
+
+ img_height, img_width = img.shape[:2]
+ target_width, target_height = size
+
+ # If image is smaller than target size, pad the image to 640x640
+ if img_width <= target_width and img_height <= target_height:
+ # Calculate padding amounts (top, bottom, left, right)
+ if img_width == target_width and img_height == target_height:
+ return [img], [[0, 0, img_width, img_height]]
+ padded_img = padding_image(img, size)
+
+ # Return the single padded tile and the extent of the original content within it
+ return [padded_img], [[0, 0, img_width, img_height]]
+
+ img_height, img_width = img.shape[:2]
+ # If image is larger than or equal to target size, crop it into 640x640 tiles
+ crop_positions = []
+ count = 0
+ cropped_img_list = []
+ for top in range(0, img_height - over_lap, target_height - over_lap):
+ for left in range(0, img_width - over_lap, target_width - over_lap):
+ # Calculate the bottom and right boundaries for the crop
+ right = min(left + target_width, img_width)
+ bottom = min(top + target_height, img_height)
+ if right >= img_width:
+ right = img_width
+ left = max(0, right - target_width)
+ if bottom >= img_height:
+ bottom = img_height
+ top = max(0, bottom - target_height)
+ # Crop the image
+ cropped_img = img[top:bottom, left:right]
+ if bottom - top < target_height or right - left < target_width:
+ cropped_img = padding_image(cropped_img, size)
+ count += 1
+ cropped_img_list.append(cropped_img)
+
+ # Record the position of the cropped image
+ crop_positions.append([left, top, right, bottom])
+
+ return cropped_img_list, crop_positions
+
+
+def restore_preds(preds, crop_positions, original_size):
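+ # Stitch the per-tile probability maps back into a full-size map: binarize each tile at 0.3 and accumulate it at its recorded position.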
+
+ restored_pred = torch.zeros((1, 1, original_size[0], original_size[1]),
+ dtype=preds.dtype,
+ device=preds.device)
+ count = 0
+ for cropped_pred, (left, top, right, bottom) in zip(preds, crop_positions):
+
+ crop_height = bottom - top
+ crop_width = right - left
+
+ corp_vis_img = cropped_pred[:, :crop_height, :crop_width]
+ mask = corp_vis_img > 0.3
+ count += 1
+ restored_pred[:, :, top:top + crop_height, left:left +
+ crop_width] += mask[:, :crop_height, :crop_width].to(
+ preds.dtype)
+
+ return restored_pred
+
+
+def draw_det_res(dt_boxes, img, img_name, save_path):
+ src_im = img
+ for box in dt_boxes:
+ box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
+ cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ save_path = os.path.join(save_path, os.path.basename(img_name))
+ cv2.imwrite(save_path, src_im)
+
+
+def set_device(device, numId=0):
+ if device == 'gpu' and torch.cuda.is_available():
+ device = torch.device(f'cuda:{numId}')
+ else:
+ logger.info('Using CPU (GPU not requested or not available).')
+ device = torch.device('cpu')
+ return device
+
+
+class OpenDetector(object):
+
+ def __init__(self, config=None, numId=0):
+ """
+ Initialize the detector.
+
+ Args:
+ config (dict, optional): configuration dict. Defaults to None, in which case the default config file is used.
+ numId (int, optional): device index. Defaults to 0.
+
+ Returns:
+ None
+
+ Raises:
+ None
+ """
+
+ if config is None:
+ config = Config(DEFAULT_CFG_PATH_DET).cfg
+
+ if not os.path.exists(config['Global']['pretrained_model']):
+ config['Global']['pretrained_model'] = check_and_download_model(
+ MODEL_NAME_DET, DOWNLOAD_URL_DET)
+
+ from opendet.modeling import build_model as build_det_model
+ from opendet.postprocess import build_post_process
+ from opendet.preprocess import create_operators, transform
+ self.transform = transform
+ global_config = config['Global']
+
+ # build model
+ self.model = build_det_model(config['Architecture'])
+ self.model.eval()
+ load_ckpt(self.model, config)
+ replace_batchnorm(self.model.backbone)
+ self.device = set_device(config['Global']['device'], numId=numId)
+ self.model.to(device=self.device)
+
+ # create data ops
+ transforms = []
+ for op in config['Eval']['dataset']['transforms']:
+ op_name = list(op)[0]
+ if 'Label' in op_name:
+ continue
+ elif op_name == 'KeepKeys':
+ op[op_name]['keep_keys'] = ['image', 'shape']
+ transforms.append(op)
+
+ self.ops = create_operators(transforms, global_config)
+
+ # build post process
+ self.post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+
+ def crop_infer(
+ self,
+ img_path=None,
+ img_numpy_list=None,
+ img_numpy=None,
+ ):
+ if img_numpy is not None:
+ img_numpy_list = [img_numpy]
+ num_img = 1
+ elif img_path is not None:
+ num_img = len(img_path)
+ elif img_numpy_list is not None:
+ num_img = len(img_numpy_list)
+ else:
+ raise Exception('No input image path or numpy array.')
+ results = []
+ for img_idx in range(num_img):
+ if img_numpy_list is not None:
+ img = img_numpy_list[img_idx]
+ data = {'image': img}
+ elif img_path is not None:
+ with open(img_path[img_idx], 'rb') as f:
+ img = f.read()
+ data = {'image': img}
+ data = self.transform(data, self.ops[:1])
+ src_img_ori = data['image']
+ img_height, img_width = src_img_ori.shape[:2]
+
+ target_size = 640
+ over_lap = 64
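+ # Resize so the longer side spans two tiles minus the overlap (2*640 - 64), then split into overlapping 640x640 tiles.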
+ if img_height > img_width:
+ r_h = target_size * 2 - over_lap
+ r_w = img_width * (target_size * 2 - over_lap) // img_height
+ else:
+ r_w = target_size * 2 - over_lap
+ r_h = img_height * (target_size * 2 - over_lap) // img_width
+ src_img = cv2.resize(src_img_ori, (r_w, r_h))
+ shape_list_ori = np.array([[
+ img_height, img_width,
+ float(r_h) / img_height,
+ float(r_w) / img_width
+ ]])
+ img_height, img_width = src_img.shape[:2]
+ cropped_img_list, crop_positions = resize_image(src_img,
+ size=(target_size,
+ target_size),
+ over_lap=over_lap)
+
+ image_list = []
+ shape_list = []
+ for img in cropped_img_list:
+ batch_i = self.transform({'image': img}, self.ops[-3:-1])
+ image_list.append(batch_i['image'])
+ shape_list.append([640, 640, 1, 1])
+ images = np.array(image_list)
+ shape_list = np.array(shape_list)
+ images = torch.from_numpy(images).to(device=self.device)
+ with torch.no_grad():
+ t_start = time.time()
+ preds = self.model(images)
+ t_cost = time.time() - t_start
+
+ preds['maps'] = restore_preds(preds['maps'], crop_positions,
+ (img_height, img_width))
+ post_result = self.post_process_class(preds, shape_list_ori)
+ info = {'boxes': post_result[0]['points'], 'elapse': t_cost}
+ results.append(info)
+ return results
+
+ def __call__(self,
+ img_path=None,
+ img_numpy_list=None,
+ img_numpy=None,
+ return_mask=False):
+ """
+ Run detection on the input image(s) and return the results.
+
+ Args:
+ img_path (str, optional): path to an image file or directory. Defaults to None.
+ img_numpy_list (list, optional): list of images, each a numpy array. Defaults to None.
+ img_numpy (numpy.ndarray, optional): a single image as a numpy array. Defaults to None.
+
+ Returns:
+ list: one dict per image with the keys 'boxes' and 'elapse';
+ 'boxes' holds the detected box points and 'elapse' the inference time.
+
+ Raises:
+ Exception: if neither an image path nor a numpy array is provided.
+
+ """
+
+ if img_numpy is not None:
+ img_numpy_list = [img_numpy]
+ num_img = 1
+ elif img_path is not None:
+ img_path = get_image_file_list(img_path)
+ num_img = len(img_path)
+ elif img_numpy_list is not None:
+ num_img = len(img_numpy_list)
+ else:
+ raise Exception('No input image path or numpy array.')
+ results = []
+ for img_idx in range(num_img):
+ if img_numpy_list is not None:
+ img = img_numpy_list[img_idx]
+ data = {'image': img}
+ elif img_path is not None:
+ with open(img_path[img_idx], 'rb') as f:
+ img = f.read()
+ data = {'image': img}
+ data = self.transform(data, self.ops[:1])
+ batch = self.transform(data, self.ops[1:])
+
+ images = np.expand_dims(batch[0], axis=0)
+ shape_list = np.expand_dims(batch[1], axis=0)
+ images = torch.from_numpy(images).to(device=self.device)
+ with torch.no_grad():
+ t_start = time.time()
+ preds = self.model(images)
+ t_cost = time.time() - t_start
+ post_result = self.post_process_class(preds, shape_list)
+
+ info = {'boxes': post_result[0]['points'], 'elapse': t_cost}
+ if return_mask:
+ if isinstance(preds['maps'], torch.Tensor):
+ mask = preds['maps'].detach().cpu().numpy()
+ else:
+ mask = preds['maps']
+ info['mask'] = mask
+ results.append(info)
+ return results
+
+
+@torch.no_grad()
+def main(cfg):
+ is_visualize = cfg['Global'].get('is_visualize', False)
+ model = OpenDetector(cfg)
+
+ save_res_path = './det_results/'
+ if not os.path.exists(save_res_path):
+ os.makedirs(save_res_path)
+ sample_num = 0
+ with open(save_res_path + '/det_results.txt', 'wb') as fout:
+ for file in get_image_file_list(cfg['Global']['infer_img']):
+ preds_result = model(img_path=file)[0]
+ logger.info('{} infer_img: {}, time cost: {}'.format(
+ sample_num, file, preds_result['elapse']))
+ boxes = preds_result['boxes']
+ dt_boxes_json = []
+ for box in boxes:
+ tmp_json = {}
+ tmp_json['points'] = np.array(box).tolist()
+ dt_boxes_json.append(tmp_json)
+ if is_visualize:
+ src_img = cv2.imread(file)
+ draw_det_res(boxes, src_img, file, save_res_path)
+ logger.info('The detected Image saved in {}'.format(
+ os.path.join(save_res_path, os.path.basename(file))))
+ otstr = file + '\t' + json.dumps(dt_boxes_json) + '\n'
+ logger.info('results: {}'.format(json.dumps(dt_boxes_json)))
+ fout.write(otstr.encode())
+ sample_num += 1
+ logger.info(
+ f"Results saved to {os.path.join(save_res_path, 'det_results.txt')}.)"
+ )
+
+ logger.info('success!')
+
+
+if __name__ == '__main__':
+ FLAGS = ArgsParser().parse_args()
+ cfg = Config(FLAGS.config)
+ FLAGS = vars(FLAGS)
+ opt = FLAGS.pop('opt')
+ cfg.merge_dict(FLAGS)
+ cfg.merge_dict(opt)
+ main(cfg.cfg)
diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py
new file mode 100644
index 0000000000000000000000000000000000000000..90c567232501198c828e1db53353ec346f74c2de
--- /dev/null
+++ b/tools/infer_e2e.py
@@ -0,0 +1,463 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from pathlib import Path
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
+import argparse
+import numpy as np
+import copy
+import time
+import cv2
+import json
+from PIL import Image
+from tools.utils.utility import get_image_file_list, check_and_read
+from tools.infer_rec import OpenRecognizer
+from tools.infer_det import OpenDetector
+from tools.engine import Config
+from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt
+from tools.utils.logging import get_logger
+
+root_dir = Path(__file__).resolve().parent
+DEFAULT_CFG_PATH_DET = str(root_dir / '../configs/det/dbnet/repvit_db.yml')
+DEFAULT_CFG_PATH_REC_SERVER = str(root_dir /
+ '../configs/det/svtrv2/svtrv2_ch.yml')
+DEFAULT_CFG_PATH_REC = str(root_dir / '../configs/rec/svtrv2/repsvtr_ch.yml')
+
+logger = get_logger()
+
+MODEL_NAME_DET = './openocr_det_repvit_ch.pth' # model file name
+DOWNLOAD_URL_DET = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth' # model file URL
+MODEL_NAME_REC = './openocr_repsvtr_ch.pth' # model file name
+DOWNLOAD_URL_REC = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth' # model file URL
+MODEL_NAME_REC_SERVER = './openocr_svtrv2_ch.pth' # model file name
+DOWNLOAD_URL_REC_SERVER = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth' # model file URL
+
+
+def check_and_download_model(model_name: str, url: str):
+ """
+ Check whether the pretrained model exists; if not, download it from the given URL into a fixed cache directory.
+
+ Args:
+ model_name (str): name of the model file, e.g. "model.pt"
+ url (str): download URL of the model file
+
+ Returns:
+ str: full path to the model file
+ """
+ if os.path.exists(model_name):
+ return model_name
+
+ # Cache path is fixed to ".cache/openocr" under the user's home directory
+ cache_dir = Path.home() / '.cache' / 'openocr'
+ model_path = cache_dir / model_name
+
+ # If the model file already exists, return its path directly
+ if model_path.exists():
+ logger.info(f'Model already exists at: {model_path}')
+ return str(model_path)
+
+ # Otherwise download the model
+ logger.info(f'Model not found. Downloading from {url}...')
+
+ # Create the cache directory if it does not exist
+ cache_dir.mkdir(parents=True, exist_ok=True)
+
+ try:
+ # Download the file
+ import urllib.request
+ with urllib.request.urlopen(url) as response, open(model_path,
+ 'wb') as out_file:
+ out_file.write(response.read())
+ logger.info(f'Model downloaded and saved at: {model_path}')
+ return str(model_path)
+
+ except Exception as e:
+ logger.error(f'Error downloading the model: {e}')
+ # Prompt the user to download the model manually
+ logger.error(
+ f'Unable to download the model automatically. '
+ f'Please download the model manually from the following URL:\n{url}\n'
+ f'and save it to: {model_name} or {model_path}')
+ raise RuntimeError(
+ f'Failed to download the model. Please download it manually from {url} '
+ f'and save it to {model_path}') from e
+
+
+def check_and_download_font(font_path):
+ if not os.path.exists(font_path):
+ cache_dir = Path.home() / '.cache' / 'openocr'
+ font_path = str(cache_dir / font_path)
+ if os.path.exists(font_path):
+ return font_path
+ logger.info(f"Downloading '{font_path}' ...")
+ try:
+ import urllib.request
+ font_url = 'https://shuiche-shop.oss-cn-chengdu.aliyuncs.com/fonts/simfang.ttf'
+ urllib.request.urlretrieve(font_url, font_path)
+ logger.info(f'Font downloaded successfully: {font_path}')
+ except Exception as e:
+ logger.info(f'Font download failed: {e}')
+ return font_path
+
+
+def sorted_boxes(dt_boxes):
+ """
+ Sort text boxes in order from top to bottom, left to right
+ args:
+ dt_boxes(array):detected text boxes with shape [4, 2]
+ return:
+ sorted boxes(array) with shape [4, 2]
+ """
+ num_boxes = dt_boxes.shape[0]
+ sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+ _boxes = list(sorted_boxes)
+
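+ # Treat boxes whose top y-coordinates differ by less than 10 px as the same text line and order them left to right.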
+ for i in range(num_boxes - 1):
+ for j in range(i, -1, -1):
+ if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
+ _boxes[j + 1][0][0] < _boxes[j][0][0]):
+ tmp = _boxes[j]
+ _boxes[j] = _boxes[j + 1]
+ _boxes[j + 1] = tmp
+ else:
+ break
+ return _boxes
+
+
+class OpenOCR(object):
+
+ def __init__(self, mode='mobile', drop_score=0.5, det_box_type='quad'):
+ """
+ Initialize the OCR engine's configuration and components.
+
+ Args:
+ mode (str, optional): running mode, either 'mobile' or 'server'. Defaults to 'mobile'.
+ drop_score (float, optional): confidence threshold; results scoring below it are discarded. Defaults to 0.5.
+ det_box_type (str, optional): detection box type, either 'quad' or 'poly'. Defaults to 'quad'.
+
+ Returns:
+ None.
+
+ """
+ cfg_det = Config(DEFAULT_CFG_PATH_DET).cfg # mobile model
+ model_dir = check_and_download_model(MODEL_NAME_DET, DOWNLOAD_URL_DET)
+ cfg_det['Global']['pretrained_model'] = model_dir
+ if mode == 'server':
+ cfg_rec = Config(DEFAULT_CFG_PATH_REC_SERVER).cfg # server model
+ model_dir = check_and_download_model(MODEL_NAME_REC_SERVER,
+ DOWNLOAD_URL_REC_SERVER)
+ else:
+ cfg_rec = Config(DEFAULT_CFG_PATH_REC).cfg # mobile model
+ model_dir = check_and_download_model(MODEL_NAME_REC,
+ DOWNLOAD_URL_REC)
+ cfg_rec['Global']['pretrained_model'] = model_dir
+ self.text_detector = OpenDetector(cfg_det)
+ self.text_recognizer = OpenRecognizer(cfg_rec)
+ self.det_box_type = det_box_type
+ self.drop_score = drop_score
+
+ self.crop_image_res_index = 0
+
+ def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
+ os.makedirs(output_dir, exist_ok=True)
+ bbox_num = len(img_crop_list)
+ for bno in range(bbox_num):
+ cv2.imwrite(
+ os.path.join(output_dir,
+ f'img_crop_{bno+self.crop_image_res_index}.jpg'),
+ img_crop_list[bno],
+ )
+ self.crop_image_res_index += bbox_num
+
+ def infer_single_image(self,
+ img_numpy,
+ ori_img,
+ crop_infer=False,
+ rec_batch_num=6,
+ return_mask=False):
+ start = time.time()
+ if crop_infer:
+ dt_boxes = self.text_detector.crop_infer(
+ img_numpy=img_numpy)[0]['boxes']
+ else:
+ det_res = self.text_detector(img_numpy=img_numpy,
+ return_mask=return_mask)[0]
+ dt_boxes = det_res['boxes']
+ # logger.info(dt_boxes)
+ det_time_cost = time.time() - start
+
+ if dt_boxes is None:
+ return None, None, None
+
+ img_crop_list = []
+
+ dt_boxes = sorted_boxes(dt_boxes)
+
+ for bno in range(len(dt_boxes)):
+ tmp_box = np.array(copy.deepcopy(dt_boxes[bno])).astype(np.float32)
+ if self.det_box_type == 'quad':
+ img_crop = get_rotate_crop_image(ori_img, tmp_box)
+ else:
+ img_crop = get_minarea_rect_crop(ori_img, tmp_box)
+ img_crop_list.append(img_crop)
+
+ start = time.time()
+ rec_res = self.text_recognizer(img_numpy_list=img_crop_list,
+ batch_num=rec_batch_num)
+ rec_time_cost = time.time() - start
+
+ filter_boxes, filter_rec_res = [], []
+ rec_time_cost_sig = 0.0
+ for box, rec_result in zip(dt_boxes, rec_res):
+ text, score = rec_result['text'], rec_result['score']
+ rec_time_cost_sig += rec_result['elapse']
+ if score >= self.drop_score:
+ filter_boxes.append(box)
+ filter_rec_res.append([text, score])
+
+ avg_rec_time_cost = rec_time_cost_sig / len(dt_boxes) if len(
+ dt_boxes) > 0 else 0.0
+ if return_mask:
+ return filter_boxes, filter_rec_res, {
+ 'time_cost': det_time_cost + rec_time_cost,
+ 'detection_time': det_time_cost,
+ 'recognition_time': rec_time_cost,
+ 'avg_rec_time_cost': avg_rec_time_cost
+ }, det_res['mask']
+
+ return filter_boxes, filter_rec_res, {
+ 'time_cost': det_time_cost + rec_time_cost,
+ 'detection_time': det_time_cost,
+ 'recognition_time': rec_time_cost,
+ 'avg_rec_time_cost': avg_rec_time_cost
+ }
+
+ def __call__(self,
+ img_path=None,
+ save_dir='e2e_results/',
+ is_visualize=False,
+ img_numpy=None,
+ rec_batch_num=6,
+ crop_infer=False,
+ return_mask=False):
+ """
+ img_path: str, optional, default=None
+ Path to the directory containing images or the image filename.
+ save_dir: str, optional, default='e2e_results/'
+ Directory to save prediction and visualization results. Defaults to 'e2e_results/'.
+ is_visualize: bool, optional, default=False
+ Visualize the results.
+ img_numpy: numpy or list[numpy], optional, default=None
+ A numpy array of a single image, or a list of numpy arrays representing images.
+ rec_batch_num: int, optional, default=6
+ Batch size for text recognition.
+ crop_infer: bool, optional, default=False
+ Whether to use crop inference.
+ """
+
+ if img_numpy is None and img_path is None:
+ raise ValueError('img_path and img_numpy cannot be both None.')
+ if img_numpy is not None:
+ if not isinstance(img_numpy, list):
+ img_numpy = [img_numpy]
+ results = []
+ time_dicts = []
+ for index, img in enumerate(img_numpy):
+ ori_img = img.copy()
+ if return_mask:
+ dt_boxes, rec_res, time_dict, mask = self.infer_single_image(
+ img_numpy=img,
+ ori_img=ori_img,
+ crop_infer=crop_infer,
+ rec_batch_num=rec_batch_num,
+ return_mask=return_mask)
+ else:
+ dt_boxes, rec_res, time_dict = self.infer_single_image(
+ img_numpy=img,
+ ori_img=ori_img,
+ crop_infer=crop_infer,
+ rec_batch_num=rec_batch_num)
+ if dt_boxes is None:
+ results.append([])
+ time_dicts.append({})
+ continue
+ res = [{
+ 'transcription': rec_res[i][0],
+ 'points': np.array(dt_boxes[i]).tolist(),
+ 'score': rec_res[i][1],
+ } for i in range(len(dt_boxes))]
+ results.append(res)
+ time_dicts.append(time_dict)
+ if return_mask:
+ return results, time_dicts, mask
+ return results, time_dicts
+
+ image_file_list = get_image_file_list(img_path)
+ save_results = []
+ time_dicts_return = []
+ for idx, image_file in enumerate(image_file_list):
+ img, flag_gif, flag_pdf = check_and_read(image_file)
+ if not flag_gif and not flag_pdf:
+ img = cv2.imread(image_file)
+            if not flag_pdf:
+                if img is None:
+                    logger.info(f'Error loading image {image_file}, skipped.')
+                    continue
+                imgs = [img]
+ else:
+ imgs = img
+ logger.info(
+ f'Processing {idx+1}/{len(image_file_list)}: {image_file}')
+
+ res_list = []
+ time_dicts = []
+ for index, img_numpy in enumerate(imgs):
+ ori_img = img_numpy.copy()
+ dt_boxes, rec_res, time_dict = self.infer_single_image(
+ img_numpy=img_numpy,
+ ori_img=ori_img,
+ crop_infer=crop_infer,
+ rec_batch_num=rec_batch_num)
+ if dt_boxes is None:
+ res_list.append([])
+ time_dicts.append({})
+ continue
+ res = [{
+ 'transcription': rec_res[i][0],
+ 'points': np.array(dt_boxes[i]).tolist(),
+ 'score': rec_res[i][1],
+ } for i in range(len(dt_boxes))]
+ res_list.append(res)
+ time_dicts.append(time_dict)
+
+ for index, (res, time_dict) in enumerate(zip(res_list,
+ time_dicts)):
+
+ if len(res) > 0:
+ logger.info(f'Results: {res}.')
+ logger.info(f'Time cost: {time_dict}.')
+ else:
+ logger.info('No text detected.')
+
+ if len(res_list) > 1:
+ save_pred = (os.path.basename(image_file) + '_' +
+ str(index) + '\t' +
+ json.dumps(res, ensure_ascii=False) + '\n')
+ else:
+ if len(res) > 0:
+ save_pred = (os.path.basename(image_file) + '\t' +
+ json.dumps(res, ensure_ascii=False) +
+ '\n')
+ else:
+ continue
+ save_results.append(save_pred)
+ time_dicts_return.append(time_dict)
+
+ if is_visualize and len(res) > 0:
+                    # Font and output-dir setup is idempotent, so it is safe
+                    # to repeat it for every visualized image.
+                    font_path = check_and_download_font('./simfang.ttf')
+                    os.makedirs(save_dir, exist_ok=True)
+                    draw_img_save_dir = os.path.join(save_dir, 'vis_results/')
+                    os.makedirs(draw_img_save_dir, exist_ok=True)
+                    if idx == 0 and index == 0:
+                        logger.info(
+                            f'Visualized results will be saved to {draw_img_save_dir}.'
+                        )
+ dt_boxes = [res[i]['points'] for i in range(len(res))]
+ rec_res = [
+ res[i]['transcription'] for i in range(len(res))
+ ]
+ rec_score = [res[i]['score'] for i in range(len(res))]
+ image = Image.fromarray(
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+ boxes = dt_boxes
+ txts = [rec_res[i] for i in range(len(rec_res))]
+ scores = [rec_score[i] for i in range(len(rec_res))]
+
+ draw_img = draw_ocr_box_txt(
+ image,
+ boxes,
+ txts,
+ scores,
+ drop_score=self.drop_score,
+ font_path=font_path,
+ )
+ if flag_gif:
+ save_file = image_file[:-3] + 'png'
+ elif flag_pdf:
+ save_file = image_file.replace(
+ '.pdf', '_' + str(index) + '.png')
+ else:
+ save_file = image_file
+ cv2.imwrite(
+ os.path.join(draw_img_save_dir,
+ os.path.basename(save_file)),
+ draw_img[:, :, ::-1],
+ )
+
+ if save_results:
+ os.makedirs(save_dir, exist_ok=True)
+ with open(os.path.join(save_dir, 'system_results.txt'),
+ 'w',
+ encoding='utf-8') as f:
+ f.writelines(save_results)
+ logger.info(
+ f"Results saved to {os.path.join(save_dir, 'system_results.txt')}."
+ )
+ if is_visualize:
+ logger.info(
+ f'Visualized results saved to {draw_img_save_dir}.')
+ return save_results, time_dicts_return
+ else:
+ logger.info('No text detected.')
+ return None, None
+
+
+def main():
+ parser = argparse.ArgumentParser(description='OpenOCR system')
+ parser.add_argument(
+ '--img_path',
+ type=str,
+ help='Path to the directory containing images or the image filename.')
+ parser.add_argument(
+ '--mode',
+ type=str,
+ default='mobile',
+ help="Mode of the OCR system, e.g., 'mobile' or 'server'.")
+ parser.add_argument(
+ '--save_dir',
+ type=str,
+ default='e2e_results/',
+ help='Directory to save prediction and visualization results. \
+ Defaults to ./e2e_results/.')
+ parser.add_argument('--is_vis',
+ action='store_true',
+ default=False,
+ help='Visualize the results.')
+ parser.add_argument('--drop_score',
+ type=float,
+ default=0.5,
+ help='Score threshold for text recognition.')
+ args = parser.parse_args()
+
+ img_path = args.img_path
+ mode = args.mode
+ save_dir = args.save_dir
+ is_visualize = args.is_vis
+ drop_score = args.drop_score
+
+ text_sys = OpenOCR(mode=mode, drop_score=drop_score,
+ det_box_type='quad') # det_box_type: 'quad' or 'poly'
+ text_sys(img_path=img_path, save_dir=save_dir, is_visualize=is_visualize)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/infer_e2e_parallel.py b/tools/infer_e2e_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..f13a5e6d3ff17961a727e6950d7d74739542630f
--- /dev/null
+++ b/tools/infer_e2e_parallel.py
@@ -0,0 +1,184 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+import queue
+import os
+import sys
+import time
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+import numpy as np
+import cv2
+import json
+from PIL import Image
+from tools.utils.utility import get_image_file_list, check_and_read
+from tools.infer_rec import OpenRecognizer
+from tools.infer_det import OpenDetector
+from tools.infer_e2e import check_and_download_font, sorted_boxes
+from tools.engine import Config
+from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt
+
+
+class OpenOCRParallel:
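+    """Parallel OCR pipeline: detection runs single-threaded in the calling
+    thread and acts as the producer, pushing (image_id, boxes, crops) onto a
+    queue, while `max_rec_threads` recognition worker threads consume from it
+    and write into `self.results` under a lock.
+    """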
+
+ def __init__(self, drop_score=0.5, det_box_type='quad', max_rec_threads=1):
+ cfg_det = Config(
+ './configs/det/dbnet/repvit_db.yml').cfg # mobile model
+ # cfg_rec = Config('./configs/rec/svtrv2/svtrv2_ch.yml').cfg # server model
+ cfg_rec = Config(
+ './configs/rec/svtrv2/repsvtr_ch.yml').cfg # mobile model
+ self.text_detector = OpenDetector(cfg_det, numId=0)
+ self.text_recognizer = OpenRecognizer(cfg_rec, numId=0)
+ self.det_box_type = det_box_type
+ self.drop_score = drop_score
+ self.queue = queue.Queue(
+ ) # Queue to hold detected boxes for recognition
+ self.results = {}
+ self.lock = threading.Lock() # Lock for thread-safe access to results
+ self.max_rec_threads = max_rec_threads
+ self.stop_signal = threading.Event() # Signal to stop threads
+
+ def start_recognition_threads(self):
+ """Start recognition threads."""
+ self.rec_threads = []
+ for _ in range(self.max_rec_threads):
+ t = threading.Thread(target=self.recognize_text)
+ t.start()
+ self.rec_threads.append(t)
+
+ def detect_text(self, image_list):
+ """Single-threaded text detection for all images."""
+ for image_id, (img_numpy, ori_img) in enumerate(image_list):
+ dt_boxes = self.text_detector(img_numpy=img_numpy)[0]['boxes']
+ if dt_boxes is None:
+ self.results[image_id] = [] # If no boxes, set empty results
+ continue
+
+ dt_boxes = sorted_boxes(dt_boxes)
+ img_crop_list = []
+ for box in dt_boxes:
+ tmp_box = np.array(box).astype(np.float32)
+ img_crop = (get_rotate_crop_image(ori_img, tmp_box)
+ if self.det_box_type == 'quad' else
+ get_minarea_rect_crop(ori_img, tmp_box))
+ img_crop_list.append(img_crop)
+ self.queue.put(
+ (image_id, dt_boxes, img_crop_list
+ )) # Put image ID, detected box, and cropped image in queue
+
+ # Signal that no more items will be added to the queue
+ self.stop_signal.set()
+
+ def recognize_text(self):
+ """Recognize text in each cropped image."""
+ while not self.stop_signal.is_set() or not self.queue.empty():
+ try:
+ image_id, boxs, img_crop_list = self.queue.get(timeout=0.5)
+ rec_results = self.text_recognizer(
+ img_numpy_list=img_crop_list, batch_num=6)
+ for rec_result, box in zip(rec_results, boxs):
+ text, score = rec_result['text'], rec_result['score']
+ if score >= self.drop_score:
+ with self.lock:
+ # Ensure results dictionary has a list for each image ID
+ if image_id not in self.results:
+ self.results[image_id] = []
+ self.results[image_id].append({
+ 'transcription':
+ text,
+ 'points':
+ box.tolist(),
+ 'score':
+ score
+ })
+ self.queue.task_done()
+ except queue.Empty:
+ continue
+
+ def process_images(self, image_list):
+ """Process a list of images."""
+ # Initialize results dictionary
+ self.results = {i: [] for i in range(len(image_list))}
+
+ # Start recognition threads
+ t_start_1 = time.time()
+ self.start_recognition_threads()
+
+ # Start detection in the main thread
+ t_start = time.time()
+ self.detect_text(image_list)
+ print('det time:', time.time() - t_start)
+
+ # Wait for recognition threads to finish
+ for t in self.rec_threads:
+ t.join()
+ self.stop_signal.clear()
+ print('all time:', time.time() - t_start_1)
+ return self.results
+
+
+def main():
+ img_path = './testA/'
+ image_file_list = get_image_file_list(img_path)
+ drop_score = 0.5
+ text_sys = OpenOCRParallel(
+ drop_score=drop_score,
+ det_box_type='quad') # det_box_type: 'quad' or 'poly'
+    is_visualize = False
+    if is_visualize:
+        font_path = './simfang.ttf'
+        check_and_download_font(font_path)
+    # system_results.txt is written below even without visualization, so the
+    # output directory must exist in either case.
+    draw_img_save_dir = os.path.join(img_path, 'e2e_results/')
+    os.makedirs(draw_img_save_dir, exist_ok=True)
+    save_results = []
+
+ # Prepare images
+ images = []
+ t_start = time.time()
+ for image_file in image_file_list:
+ img, flag_gif, flag_pdf = check_and_read(image_file)
+ if not flag_gif and not flag_pdf:
+ img = cv2.imread(image_file)
+ if img is not None:
+ images.append((img, img.copy()))
+
+ results = text_sys.process_images(images)
+ print(f'time cost: {time.time() - t_start}')
+ # Save results and visualize
+ for image_id, res in results.items():
+ image_file = image_file_list[image_id]
+ save_pred = f'{os.path.basename(image_file)}\t{json.dumps(res, ensure_ascii=False)}\n'
+ # print(save_pred)
+ save_results.append(save_pred)
+
+ if is_visualize:
+ dt_boxes = [result['points'] for result in res]
+ rec_res = [result['transcription'] for result in res]
+ rec_score = [result['score'] for result in res]
+ image = Image.fromarray(
+ cv2.cvtColor(images[image_id][0], cv2.COLOR_BGR2RGB))
+ draw_img = draw_ocr_box_txt(image,
+ dt_boxes,
+ rec_res,
+ rec_score,
+ drop_score=drop_score,
+ font_path=font_path)
+
+ save_file = os.path.join(draw_img_save_dir,
+ os.path.basename(image_file))
+ cv2.imwrite(save_file, draw_img[:, :, ::-1])
+
+ with open(os.path.join(draw_img_save_dir, 'system_results.txt'),
+ 'w',
+ encoding='utf-8') as f:
+ f.writelines(save_results)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdd8eec381a650879c4694465a831157c779a5d2
--- /dev/null
+++ b/tools/infer_rec.py
@@ -0,0 +1,393 @@
+import os
+from pathlib import Path
+import sys
+import time
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+import numpy as np
+import torch
+from torchvision import transforms as T
+from torchvision.transforms import functional as F
+from tools.engine import Config
+from tools.utility import ArgsParser
+from tools.utils.ckpt import load_ckpt
+from tools.utils.logging import get_logger
+from tools.utils.utility import get_image_file_list
+from tools.infer_det import replace_batchnorm
+
+logger = get_logger()
+
+root_dir = Path(__file__).resolve().parent
+DEFAULT_CFG_PATH_REC_SERVER = str(root_dir /
+                                  '../configs/rec/svtrv2/svtrv2_ch.yml')
+DEFAULT_CFG_PATH_REC = str(root_dir / '../configs/rec/svtrv2/repsvtr_ch.yml')
+DEFAULT_DICT_PATH_REC = str(root_dir / './utils/ppocr_keys_v1.txt')
+
+MODEL_NAME_REC = './openocr_repsvtr_ch.pth'  # model file name
+DOWNLOAD_URL_REC = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth'  # model file URL
+MODEL_NAME_REC_SERVER = './openocr_svtrv2_ch.pth'  # model file name
+DOWNLOAD_URL_REC_SERVER = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth'  # model file URL
+
+
+def check_and_download_model(model_name: str, url: str):
+ """
+ 检查预训练模型是否存在,若不存在则从指定 URL 下载到固定缓存目录。
+
+ Args:
+ model_name (str): 模型文件的名称,例如 "model.pt"
+ url (str): 模型文件的下载地址
+
+ Returns:
+ str: 模型文件的完整路径
+ """
+ if os.path.exists(model_name):
+ return model_name
+
+    # The cache directory is fixed to ".cache/openocr" under the user's home
+ cache_dir = Path.home() / '.cache' / 'openocr'
+ model_path = cache_dir / model_name
+
+    # If the model file is already cached, return its path directly
+ if model_path.exists():
+ logger.info(f'Model already exists at: {model_path}')
+ return str(model_path)
+
+    # The file is not cached yet, download the model
+ logger.info(f'Model not found. Downloading from {url}...')
+
+    # Create the cache directory if it does not exist
+ cache_dir.mkdir(parents=True, exist_ok=True)
+
+ try:
+        # Download the file
+ import urllib.request
+ with urllib.request.urlopen(url) as response, open(model_path,
+ 'wb') as out_file:
+ out_file.write(response.read())
+ logger.info(f'Model downloaded and saved at: {model_path}')
+ return str(model_path)
+
+ except Exception as e:
+ logger.error(f'Error downloading the model: {e}')
+        # Tell the user how to download the model manually
+ logger.error(
+ f'Unable to download the model automatically. '
+ f'Please download the model manually from the following URL:\n{url}\n'
+ f'and save it to: {model_name} or {model_path}')
+ raise RuntimeError(
+ f'Failed to download the model. Please download it manually from {url} '
+ f'and save it to {model_path}') from e
+
+
+class RatioRecTVReisze(object):
+
+ def __init__(self, cfg):
+ self.max_ratio = cfg['Eval']['loader'].get('max_ratio', 12)
+ self.base_shape = cfg['Eval']['dataset'].get(
+ 'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
+ self.base_h = cfg['Eval']['dataset'].get('base_h', 32)
+ self.interpolation = T.InterpolationMode.BICUBIC
+ transforms = []
+ transforms.extend([
+ T.ToTensor(),
+ T.Normalize(0.5, 0.5),
+ ])
+ self.transforms = T.Compose(transforms)
+        self.ceil = cfg['Eval']['dataset'].get('ceil', False)
+
+ def __call__(self, data):
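+        # Bucketed aspect-ratio resize: for (rounded) ratios 1-4 the target
+        # size comes from `base_shape`; larger ratios use
+        # width = base_h * ratio, with the ratio capped at `max_ratio`.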
+ img = data['image']
+ imgH = self.base_h
+ w, h = img.size
+ if self.ceil:
+ gen_ratio = int(float(w) / float(h)) + 1
+ else:
+ gen_ratio = max(1, round(float(w) / float(h)))
+ ratio_resize = min(gen_ratio, self.max_ratio)
+        if ratio_resize <= 4:
+            imgW, imgH = self.base_shape[ratio_resize - 1]
+        else:
+            imgW, imgH = self.base_h * ratio_resize, self.base_h
+ resized_w = imgW
+ resized_image = F.resize(img, (imgH, resized_w),
+ interpolation=self.interpolation)
+ img = self.transforms(resized_image)
+ data['image'] = img
+ return data
+
+
+def build_rec_process(cfg):
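+    # Adapt the eval-time transform list for inference: label ops are dropped,
+    # resize ops are switched to infer mode, and only the image (plus
+    # valid_ratio for SAR/RobustScanner) is kept. The returned flag tells the
+    # caller whether a dynamic ratio-based resize still has to be appended.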
+ transforms = []
+ ratio_resize_flag = True
+ for op in cfg['Eval']['dataset']['transforms']:
+ op_name = list(op)[0]
+ if 'Resize' in op_name:
+ ratio_resize_flag = False
+ if 'Label' in op_name:
+ continue
+ elif op_name in ['RecResizeImg']:
+ op[op_name]['infer_mode'] = True
+ elif op_name == 'KeepKeys':
+ if cfg['Architecture']['algorithm'] in ['SAR', 'RobustScanner']:
+ if 'valid_ratio' in op[op_name]['keep_keys']:
+ op[op_name]['keep_keys'] = ['image', 'valid_ratio']
+ else:
+ op[op_name]['keep_keys'] = ['image']
+ else:
+ op[op_name]['keep_keys'] = ['image']
+ transforms.append(op)
+ return transforms, ratio_resize_flag
+
+
+def set_device(device, numId=0):
+ if device == 'gpu' and torch.cuda.is_available():
+ device = torch.device(f'cuda:{numId}')
+ else:
+ logger.info('GPU is not available, using CPU.')
+ device = torch.device('cpu')
+ return device
+
+
+class OpenRecognizer(object):
+
+ def __init__(self, config=None, mode='mobile', numId=0):
+ """
+ 初始化方法。
+
+ Args:
+ config (dict, optional): 配置信息。默认为None。
+ mode (str, optional): 模式,'server' 或 'mobile'。默认为'mobile'。
+ numId (int, optional): 设备编号。默认为0。
+
+ Returns:
+ None
+
+ Raises:
+ 无
+
+ """
+ if config is None:
+ if mode == 'server':
+ config = Config(
+ DEFAULT_CFG_PATH_REC_SERVER).cfg # server model
+ if not os.path.exists(config['Global']['pretrained_model']):
+ model_dir = check_and_download_model(
+ MODEL_NAME_REC_SERVER, DOWNLOAD_URL_REC_SERVER)
+ else:
+ config = Config(DEFAULT_CFG_PATH_REC).cfg # mobile model
+ if not os.path.exists(config['Global']['pretrained_model']):
+ model_dir = check_and_download_model(
+ MODEL_NAME_REC, DOWNLOAD_URL_REC)
+ config['Global']['pretrained_model'] = model_dir
+ config['Global']['character_dict_path'] = DEFAULT_DICT_PATH_REC
+ else:
+ if config['Architecture']['algorithm'] == 'SVTRv2_mobile':
+ if not os.path.exists(config['Global']['pretrained_model']):
+ config['Global'][
+ 'pretrained_model'] = check_and_download_model(
+ MODEL_NAME_REC, DOWNLOAD_URL_REC)
+ config['Global']['character_dict_path'] = DEFAULT_DICT_PATH_REC
+ elif config['Architecture']['algorithm'] == 'SVTRv2_server':
+ if not os.path.exists(config['Global']['pretrained_model']):
+ config['Global'][
+ 'pretrained_model'] = check_and_download_model(
+ MODEL_NAME_REC_SERVER, DOWNLOAD_URL_REC_SERVER)
+ config['Global']['character_dict_path'] = DEFAULT_DICT_PATH_REC
+ global_config = config['Global']
+ self.cfg = config
+ if global_config['pretrained_model'] is None:
+ global_config[
+ 'pretrained_model'] = global_config['output_dir'] + '/best.pth'
+ # build post process
+ from openrec.modeling import build_model as build_rec_model
+ from openrec.postprocess import build_post_process
+ from openrec.preprocess import create_operators, transform
+ self.transform = transform
+ self.post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+ char_num = self.post_process_class.get_character_num()
+ config['Architecture']['Decoder']['out_channels'] = char_num
+ # print(char_num)
+ self.model = build_rec_model(config['Architecture'])
+ load_ckpt(self.model, config)
+
+ # exit(0)
+ self.device = set_device(global_config['device'], numId=numId)
+ self.model.eval()
+ replace_batchnorm(self.model.encoder)
+ self.model.to(device=self.device)
+
+ transforms, ratio_resize_flag = build_rec_process(self.cfg)
+ global_config['infer_mode'] = True
+ self.ops = create_operators(transforms, global_config)
+ if ratio_resize_flag:
+ ratio_resize = RatioRecTVReisze(cfg=self.cfg)
+ self.ops.insert(-1, ratio_resize)
+
+ def __call__(self,
+ img_path=None,
+ img_numpy_list=None,
+ img_numpy=None,
+ batch_num=1):
+ """
+ 调用函数,处理输入图像,并返回识别结果。
+
+ Args:
+ img_path (str, optional): 图像文件的路径。默认为 None。
+ img_numpy_list (list, optional): 包含多个图像 numpy 数组的列表。默认为 None。
+ img_numpy (numpy.ndarray, optional): 单个图像的 numpy 数组。默认为 None。
+ batch_num (int, optional): 每次处理的图像数量。默认为 1。
+
+ Returns:
+ list: 包含识别结果的列表,每个元素为一个字典,包含文件路径(如果有的话)、文本、分数和延迟时间。
+
+ Raises:
+ Exception: 如果没有提供图像路径或 numpy 数组,则引发异常。
+ """
+
+ if img_numpy is not None:
+ img_numpy_list = [img_numpy]
+ num_img = 1
+ elif img_path is not None:
+ img_path = get_image_file_list(img_path)
+ num_img = len(img_path)
+ elif img_numpy_list is not None:
+ num_img = len(img_numpy_list)
+ else:
+ raise Exception('No input image path or numpy array.')
+ results = []
+ for start_idx in range(0, num_img, batch_num):
+ batch_data = []
+ batch_others = []
+ batch_file_names = []
+
+ max_width, max_height = 0, 0
+ # Prepare batch data
+ for img_idx in range(start_idx, min(start_idx + batch_num,
+ num_img)):
+ if img_numpy_list is not None:
+ img = img_numpy_list[img_idx]
+ data = {'image': img}
+ elif img_path is not None:
+ file_name = img_path[img_idx]
+ with open(file_name, 'rb') as f:
+ img = f.read()
+ data = {'image': img}
+ data = self.transform(data, self.ops[:1])
+ batch_file_names.append(file_name)
+ batch = self.transform(data, self.ops[1:])
+ others = None
+ if self.cfg['Architecture']['algorithm'] in [
+ 'SAR', 'RobustScanner'
+ ]:
+ valid_ratio = np.expand_dims(batch[-1], axis=0)
+ batch_others.append(valid_ratio)
+ # others = [torch.from_numpy(valid_ratio).to(device=self.device)]
+ resized_image = batch[0]
+ h, w = resized_image.shape[-2:]
+ max_width = max(max_width, w)
+ max_height = max(max_height, h)
+ batch_data.append(batch[0])
+
+ padded_batch_data = []
+ for resized_image in batch_data:
+ padded_image = np.zeros([1, 3, max_height, max_width],
+ dtype=np.float32)
+ h, w = resized_image.shape[-2:]
+
+                # Apply zero padding at the bottom/right up to the batch max size
+                padded_image[:, :, :h, :w] = resized_image
+ padded_batch_data.append(padded_image)
+
+ if batch_others:
+ others = np.concatenate(batch_others, axis=0)
+ else:
+ others = None
+ images = np.concatenate(padded_batch_data, axis=0)
+ images = torch.from_numpy(images).to(device=self.device)
+
+ with torch.no_grad():
+ t_start = time.time()
+ preds = self.model(images, others)
+ t_cost = time.time() - t_start
+ post_results = self.post_process_class(preds)
+
+ for i, post_result in enumerate(post_results):
+ if img_path is not None:
+ info = {
+ 'file': batch_file_names[i],
+ 'text': post_result[0],
+ 'score': post_result[1],
+ 'elapse': t_cost
+ }
+ else:
+ info = {
+ 'text': post_result[0],
+ 'score': post_result[1],
+ 'elapse': t_cost
+ }
+ results.append(info)
+
+ return results
+
+
+def main(cfg):
+ model = OpenRecognizer(cfg)
+
+ save_res_path = cfg['Global']['output_dir']
+ if not os.path.exists(save_res_path):
+ os.makedirs(save_res_path)
+
+ t_sum = 0
+ sample_num = 0
+ max_len = cfg['Global']['max_text_length']
+ text_len_time = [0 for _ in range(max_len)]
+ text_len_num = [0 for _ in range(max_len)]
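+    # Latency is accumulated per predicted text length so an average time cost
+    # can be reported for each length bucket at the end.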
+
+ sample_num = 0
+ with open(save_res_path + '/rec_results.txt', 'wb') as fout:
+ for file in get_image_file_list(cfg['Global']['infer_img']):
+
+ preds_result = model(img_path=file, batch_num=1)[0]
+
+ rec_text = preds_result['text']
+ score = preds_result['score']
+ t_cost = preds_result['elapse']
+ info = rec_text + '\t' + str(score)
+ text_len_num[min(max_len - 1, len(rec_text))] += 1
+ text_len_time[min(max_len - 1, len(rec_text))] += t_cost
+ logger.info(
+ f'{sample_num} {file}\t result: {info}, time cost: {t_cost}')
+ otstr = file + '\t' + info + '\n'
+ t_sum += t_cost
+ fout.write(otstr.encode())
+ sample_num += 1
+
+ print(text_len_num)
+ w_avg_t_cost = []
+ for l_t_cost, l_num in zip(text_len_time, text_len_num):
+ if l_num != 0:
+ w_avg_t_cost.append(l_t_cost / l_num)
+ print(w_avg_t_cost)
+ w_avg_t_cost = sum(w_avg_t_cost) / len(w_avg_t_cost)
+
+    logger.info(
+        f'Sample num: {sample_num}, Avg time cost: {t_sum/sample_num}, '
+        f'Avg over length buckets: {w_avg_t_cost}')
+ logger.info('success!')
+
+
+if __name__ == '__main__':
+ FLAGS = ArgsParser().parse_args()
+ cfg = Config(FLAGS.config)
+ FLAGS = vars(FLAGS)
+ opt = FLAGS.pop('opt')
+ cfg.merge_dict(FLAGS)
+ cfg.merge_dict(opt)
+ main(cfg.cfg)
diff --git a/tools/train_rec.py b/tools/train_rec.py
new file mode 100644
index 0000000000000000000000000000000000000000..1351c97bef7e4e974331b22a8ee6ae742fe5b5b6
--- /dev/null
+++ b/tools/train_rec.py
@@ -0,0 +1,37 @@
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+from tools.engine import Config, Trainer
+from tools.utility import ArgsParser
+
+
+def parse_args():
+ parser = ArgsParser()
+ parser.add_argument(
+ '--eval',
+ action='store_true',
+ default=True,
+ help='Whether to perform evaluation in train',
+ )
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ FLAGS = parse_args()
+ cfg = Config(FLAGS.config)
+ FLAGS = vars(FLAGS)
+ opt = FLAGS.pop('opt')
+ cfg.merge_dict(FLAGS)
+ cfg.merge_dict(opt)
+ trainer = Trainer(cfg, mode='train_eval' if FLAGS['eval'] else 'train')
+ trainer.train()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/utility.py b/tools/utility.py
new file mode 100644
index 0000000000000000000000000000000000000000..86d83d802ff95e4d57fa3de7642efbcf5e8cc30d
--- /dev/null
+++ b/tools/utility.py
@@ -0,0 +1,45 @@
+from argparse import ArgumentParser, RawDescriptionHelpFormatter
+
+import yaml
+
+
+class ArgsParser(ArgumentParser):
+
+ def __init__(self):
+ super(ArgsParser,
+ self).__init__(formatter_class=RawDescriptionHelpFormatter)
+ self.add_argument('-c', '--config', help='configuration file to use')
+ self.add_argument('-o',
+ '--opt',
+ nargs='*',
+ help='set configuration options')
+ self.add_argument('--local_rank')
+ self.add_argument('--local-rank')
+
+ def parse_args(self, argv=None):
+ args = super(ArgsParser, self).parse_args(argv)
+ assert args.config is not None, 'Please specify --config=configure_file_path.'
+ args.opt = self._parse_opt(args.opt)
+ return args
+
+ def _parse_opt(self, opts):
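+        # Each opt is a `key=value` pair; dotted keys become nested dicts, e.g.
+        #     -o Global.device=gpu Train.loader.batch_size=16
+        # parses to {'Global': {'device': 'gpu'},
+        #            'Train': {'loader': {'batch_size': 16}}} (keys here are
+        # only illustrative).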
+ config = {}
+ if not opts:
+ return config
+ for s in opts:
+ s = s.strip()
+ k, v = s.split('=', 1)
+ if '.' not in k:
+ config[k] = yaml.load(v, Loader=yaml.Loader)
+ else:
+ keys = k.split('.')
+ if keys[0] not in config:
+ config[keys[0]] = {}
+ cur = config[keys[0]]
+ for idx, key in enumerate(keys[1:]):
+ if idx == len(keys) - 2:
+ cur[key] = yaml.load(v, Loader=yaml.Loader)
+ else:
+ cur[key] = {}
+ cur = cur[key]
+ return config
diff --git a/tools/utils/EN_symbol_dict.txt b/tools/utils/EN_symbol_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1aef43d6b842731a54cbe682ccda5c2dbfa694d9
--- /dev/null
+++ b/tools/utils/EN_symbol_dict.txt
@@ -0,0 +1,94 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+:
+;
+<
+=
+>
+?
+@
+[
+\
+]
+^
+_
+`
+{
+|
+}
+~
\ No newline at end of file
diff --git a/tools/utils/__init__.py b/tools/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tools/utils/__pycache__/__init__.cpython-38.pyc b/tools/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7dca5ce44b79072c14eba0f4c7593d840f0c9eb
Binary files /dev/null and b/tools/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/tools/utils/__pycache__/ckpt.cpython-38.pyc b/tools/utils/__pycache__/ckpt.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d9ca7a98fb65bda7c8e24c12364be27bdbdf3fc
Binary files /dev/null and b/tools/utils/__pycache__/ckpt.cpython-38.pyc differ
diff --git a/tools/utils/__pycache__/logging.cpython-38.pyc b/tools/utils/__pycache__/logging.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..123e9ebd5803301683cb27c9a351d32584b88ffb
Binary files /dev/null and b/tools/utils/__pycache__/logging.cpython-38.pyc differ
diff --git a/tools/utils/__pycache__/stats.cpython-38.pyc b/tools/utils/__pycache__/stats.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d028130d7ba628e6faaedbaa91fcd398ee2550e5
Binary files /dev/null and b/tools/utils/__pycache__/stats.cpython-38.pyc differ
diff --git a/tools/utils/__pycache__/utility.cpython-38.pyc b/tools/utils/__pycache__/utility.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5d5fa60f778ffe00e70fc65e0743952389c6a70
Binary files /dev/null and b/tools/utils/__pycache__/utility.cpython-38.pyc differ
diff --git a/tools/utils/ckpt.py b/tools/utils/ckpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1ea8c83549e7c4251d650854bf1db42a78ba8fa
--- /dev/null
+++ b/tools/utils/ckpt.py
@@ -0,0 +1,87 @@
+import os
+
+import torch
+
+from tools.utils.logging import get_logger
+
+
+def save_ckpt(
+ model,
+ cfg,
+ optimizer,
+ lr_scheduler,
+ epoch,
+ global_step,
+ metrics,
+ is_best=False,
+ logger=None,
+ prefix=None,
+):
+ """
+ Saving checkpoints
+
+ :param epoch: current epoch number
+ :param log: logging information of the epoch
+ :param save_best: if True, rename the saved checkpoint to 'model_best.pth.tar'
+ """
+ if logger is None:
+ logger = get_logger()
+ if prefix is None:
+ if is_best:
+ save_path = os.path.join(cfg["Global"]["output_dir"], "best.pth")
+ else:
+ save_path = os.path.join(cfg["Global"]["output_dir"], "latest.pth")
+ else:
+ save_path = os.path.join(cfg["Global"]["output_dir"], prefix + ".pth")
+ state_dict = model.module.state_dict() if cfg["Global"]["distributed"] else model.state_dict()
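+    # "best" checkpoints keep only the weights, config and metrics; optimizer
+    # and scheduler states are dropped so the exported file stays small.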
+ state = {
+ "epoch": epoch,
+ "global_step": global_step,
+ "state_dict": state_dict,
+ "optimizer": None if is_best else optimizer.state_dict(),
+ "scheduler": None if is_best else lr_scheduler.state_dict(),
+ "config": cfg,
+ "metrics": metrics,
+ }
+ torch.save(state, save_path)
+ logger.info(f"save ckpt to {save_path}")
+
+
+def load_ckpt(model, cfg, optimizer=None, lr_scheduler=None, logger=None):
+ """
+ Resume from saved checkpoints
+ :param checkpoint_path: Checkpoint path to be resumed
+ """
+ if logger is None:
+ logger = get_logger()
+ checkpoints = cfg["Global"].get("checkpoints")
+ pretrained_model = cfg["Global"].get("pretrained_model")
+
+ status = {}
+ if checkpoints and os.path.exists(checkpoints):
+ checkpoint = torch.load(checkpoints, map_location=torch.device("cpu"))
+ model.load_state_dict(checkpoint["state_dict"], strict=True)
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ if lr_scheduler is not None:
+ lr_scheduler.load_state_dict(checkpoint["scheduler"])
+ logger.info(f"resume from checkpoint {checkpoints} (epoch {checkpoint['epoch']})")
+
+ status["global_step"] = checkpoint["global_step"]
+ status["epoch"] = checkpoint["epoch"] + 1
+ status["metrics"] = checkpoint["metrics"]
+ elif pretrained_model and os.path.exists(pretrained_model):
+ load_pretrained_params(model, pretrained_model, logger)
+ logger.info(f"finetune from checkpoint {pretrained_model}")
+ else:
+ logger.info("train from scratch")
+ return status
+
+
+def load_pretrained_params(model, pretrained_model, logger):
+ checkpoint = torch.load(pretrained_model, map_location=torch.device("cpu"))
+ model.load_state_dict(checkpoint["state_dict"], strict=False)
+ for name in model.state_dict().keys():
+ if name not in checkpoint["state_dict"]:
+ logger.info(f"{name} is not in pretrained model")
+
diff --git a/tools/utils/dict/ar_dict.txt b/tools/utils/dict/ar_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fc6380293eb51754b3d82a4b20ce4bdc297d56ed
--- /dev/null
+++ b/tools/utils/dict/ar_dict.txt
@@ -0,0 +1,117 @@
+a
+r
+b
+i
+c
+_
+m
+g
+/
+1
+0
+I
+L
+S
+V
+R
+C
+2
+v
+l
+6
+3
+9
+.
+j
+p
+ا
+ل
+م
+ر
+ج
+و
+ح
+ي
+ة
+5
+8
+7
+أ
+ب
+ض
+4
+ك
+س
+ه
+ث
+ن
+ط
+ع
+ت
+غ
+خ
+ف
+ئ
+ز
+إ
+د
+ص
+ظ
+ذ
+ش
+ى
+ق
+ؤ
+آ
+ء
+s
+e
+n
+w
+t
+u
+z
+d
+A
+N
+G
+h
+o
+E
+T
+H
+O
+B
+y
+F
+U
+J
+X
+W
+P
+Z
+M
+k
+q
+Y
+Q
+D
+f
+K
+x
+'
+%
+-
+#
+@
+!
+&
+$
+,
+:
+é
+?
++
+É
+(
+
diff --git a/tools/utils/dict/arabic_dict.txt b/tools/utils/dict/arabic_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..916d421c53bad563dfd980c1b64dcce07a3c9d24
--- /dev/null
+++ b/tools/utils/dict/arabic_dict.txt
@@ -0,0 +1,161 @@
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+ء
+آ
+أ
+ؤ
+إ
+ئ
+ا
+ب
+ة
+ت
+ث
+ج
+ح
+خ
+د
+ذ
+ر
+ز
+س
+ش
+ص
+ض
+ط
+ظ
+ع
+غ
+ف
+ق
+ك
+ل
+م
+ن
+ه
+و
+ى
+ي
+ً
+ٌ
+ٍ
+َ
+ُ
+ِ
+ّ
+ْ
+ٓ
+ٔ
+ٰ
+ٱ
+ٹ
+پ
+چ
+ڈ
+ڑ
+ژ
+ک
+ڭ
+گ
+ں
+ھ
+ۀ
+ہ
+ۂ
+ۃ
+ۆ
+ۇ
+ۈ
+ۋ
+ی
+ې
+ے
+ۓ
+ە
+١
+٢
+٣
+٤
+٥
+٦
+٧
+٨
+٩
diff --git a/tools/utils/dict/be_dict.txt b/tools/utils/dict/be_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f8458baaf2f1b1da82dc56bd259ca6fb3e887b89
--- /dev/null
+++ b/tools/utils/dict/be_dict.txt
@@ -0,0 +1,145 @@
+b
+e
+_
+i
+m
+g
+/
+2
+0
+I
+L
+S
+V
+R
+C
+1
+v
+a
+l
+6
+9
+4
+3
+.
+j
+p
+п
+а
+з
+б
+у
+г
+н
+ц
+ь
+8
+м
+л
+і
+о
+ў
+ы
+7
+5
+М
+х
+с
+р
+ф
+я
+е
+д
+ж
+ю
+ч
+й
+к
+Д
+в
+Б
+т
+І
+ш
+ё
+э
+К
+Л
+Н
+А
+Ж
+Г
+В
+П
+З
+Е
+О
+Р
+С
+У
+Ё
+Й
+Т
+Ч
+Э
+Ц
+Ю
+Ш
+Ф
+Х
+Я
+Ь
+Ы
+Ў
+s
+c
+n
+w
+M
+o
+t
+T
+E
+A
+B
+u
+h
+y
+k
+r
+H
+d
+Y
+O
+U
+F
+f
+x
+D
+G
+N
+K
+P
+z
+J
+X
+W
+Z
+Q
+%
+-
+q
+@
+'
+!
+#
+&
+,
+:
+$
+(
+?
+é
++
+É
+
diff --git a/tools/utils/dict/bg_dict.txt b/tools/utils/dict/bg_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..84713c373b5df1f59d51bd9c505721d8ec239b98
--- /dev/null
+++ b/tools/utils/dict/bg_dict.txt
@@ -0,0 +1,140 @@
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+А
+Б
+В
+Г
+Д
+Е
+Ж
+З
+И
+Й
+К
+Л
+М
+Н
+О
+П
+Р
+С
+Т
+У
+Ф
+Х
+Ц
+Ч
+Ш
+Щ
+Ъ
+Ю
+Я
+а
+б
+в
+г
+д
+е
+ж
+з
+и
+й
+к
+л
+м
+н
+о
+п
+р
+с
+т
+у
+ф
+х
+ц
+ч
+ш
+щ
+ъ
+ь
+ю
+я
+
diff --git a/tools/utils/dict/chinese_cht_dict.txt b/tools/utils/dict/chinese_cht_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cc1aa4724b9a6f0e15275bcf61c91c26b6550c3e
--- /dev/null
+++ b/tools/utils/dict/chinese_cht_dict.txt
@@ -0,0 +1,8421 @@
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+\
+]
+^
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+|
+}
+~
+¥
+®
+°
+±
+²
+´
+·
+»
+É
+Ë
+Ó
+×
+Ü
+à
+á
+ä
+è
+é
+ì
+í
+ò
+ó
+÷
+ú
+ü
+ā
+ē
+ī
+ō
+ū
+ǐ
+ǒ
+ɔ
+ɡ
+ʌ
+ˋ
+Λ
+Ο
+Φ
+Ω
+α
+β
+ε
+θ
+μ
+π
+З
+И
+Й
+П
+Я
+г
+—
+‖
+‘
+’
+“
+”
+•
+…
+‧
+′
+″
+※
+℃
+№
+™
+Ⅱ
+Ⅲ
+Ⅳ
+←
+↑
+→
+↓
+⇋
+∈
+∑
+√
+∞
+∣
+∧
+∩
+∫
+∶
+≈
+≠
+≤
+≥
+⊙
+⊥
+①
+②
+③
+④
+⑧
+⑴
+⑵
+⑶
+─
+│
+┅
+┌
+├
+█
+▎
+▏
+▕
+■
+□
+▪
+▲
+△
+▼
+◆
+◇
+○
+◎
+●
+◥
+★
+☆
+❋
+❤
+
+、
+。
+〇
+〉
+《
+》
+「
+」
+『
+』
+【
+】
+〔
+〕
+〖
+〗
+の
+サ
+シ
+ジ
+マ
+ㄱ
+ㆍ
+㎏
+㎡
+㐂
+㐱
+㙟
+㴪
+㸃
+䖝
+䝉
+䰾
+䲁
+一
+丁
+七
+丄
+丈
+三
+上
+下
+丌
+不
+与
+丏
+丐
+丑
+且
+丕
+世
+丘
+丙
+丞
+丟
+両
+並
+丨
+丫
+中
+丰
+串
+丶
+丸
+丹
+主
+丼
+丿
+乂
+乃
+久
+么
+之
+乍
+乎
+乏
+乒
+乓
+乖
+乗
+乘
+乙
+乚
+乜
+九
+乞
+也
+乩
+乭
+乳
+乸
+乹
+乾
+亀
+亂
+亅
+了
+予
+亊
+事
+二
+亍
+云
+互
+亓
+五
+井
+亘
+些
+亜
+亞
+亟
+亠
+亡
+亢
+交
+亥
+亦
+亨
+享
+京
+亭
+亮
+亰
+亳
+亶
+亹
+人
+亻
+什
+仁
+仂
+仃
+仄
+仇
+仉
+今
+介
+仍
+仏
+仔
+仕
+他
+仗
+付
+仙
+仛
+仝
+仞
+仟
+仡
+代
+令
+以
+仨
+仫
+仮
+仰
+仲
+仳
+仵
+件
+仺
+任
+仼
+份
+仿
+企
+伃
+伈
+伉
+伊
+伋
+伍
+伎
+伏
+伐
+休
+伕
+伙
+伝
+伢
+伯
+估
+伱
+伴
+伶
+伷
+伸
+伺
+似
+伽
+伾
+佀
+佁
+佃
+但
+佇
+佈
+佉
+佋
+位
+低
+住
+佐
+佑
+体
+佔
+何
+佗
+佘
+余
+佚
+佛
+作
+佝
+佞
+佟
+你
+佣
+佤
+佧
+佩
+佬
+佯
+佰
+佳
+併
+佶
+佹
+佺
+佼
+佾
+使
+侁
+侃
+侄
+侅
+來
+侈
+侊
+例
+侍
+侏
+侑
+侖
+侗
+侘
+侚
+供
+依
+侞
+価
+侮
+侯
+侵
+侶
+侷
+侹
+便
+俁
+係
+促
+俄
+俅
+俊
+俋
+俌
+俍
+俎
+俏
+俐
+俑
+俗
+俘
+俚
+俛
+保
+俞
+俟
+俠
+信
+俬
+修
+俯
+俱
+俳
+俴
+俵
+俶
+俸
+俺
+俽
+俾
+倆
+倈
+倉
+個
+倌
+倍
+們
+倒
+倓
+倔
+倖
+倗
+倘
+候
+倚
+倜
+倞
+借
+倡
+倢
+倣
+値
+倦
+倧
+倩
+倪
+倫
+倬
+倭
+倮
+倻
+值
+偁
+偃
+假
+偈
+偉
+偊
+偌
+偍
+偎
+偏
+偓
+偕
+做
+停
+健
+偪
+偲
+側
+偵
+偶
+偷
+偸
+偽
+傀
+傃
+傅
+傈
+傉
+傍
+傑
+傒
+傕
+傖
+傘
+備
+傜
+傢
+傣
+催
+傭
+傲
+傳
+債
+傷
+傻
+傾
+僅
+僉
+僊
+働
+像
+僑
+僔
+僕
+僖
+僙
+僚
+僜
+僡
+僧
+僩
+僭
+僮
+僰
+僱
+僳
+僴
+僵
+價
+僻
+儀
+儁
+儂
+億
+儆
+儇
+儈
+儉
+儋
+儐
+儒
+儔
+儕
+儘
+儚
+儞
+償
+儡
+儥
+儦
+優
+儫
+儱
+儲
+儷
+儺
+儻
+儼
+兀
+允
+元
+兄
+充
+兆
+先
+光
+克
+兌
+免
+児
+兒
+兔
+兕
+兗
+兜
+入
+內
+全
+兩
+兪
+八
+公
+六
+兮
+共
+兵
+其
+具
+典
+兼
+兿
+冀
+冂
+円
+冇
+冉
+冊
+再
+冏
+冑
+冒
+冕
+冖
+冗
+冚
+冠
+冢
+冤
+冥
+冧
+冨
+冪
+冫
+冬
+冮
+冰
+冴
+冶
+冷
+冼
+冽
+凃
+凄
+准
+凈
+凋
+凌
+凍
+凖
+凜
+凝
+凞
+几
+凡
+処
+凪
+凬
+凰
+凱
+凳
+凵
+凶
+凸
+凹
+出
+函
+刀
+刁
+刂
+刃
+刄
+分
+切
+刈
+刊
+刎
+刑
+划
+列
+初
+判
+別
+刦
+刧
+刨
+利
+刪
+刮
+到
+制
+刷
+券
+刺
+刻
+刼
+剁
+剃
+則
+削
+剋
+剌
+前
+剎
+剏
+剔
+剖
+剛
+剝
+剡
+剣
+剩
+剪
+剮
+副
+割
+創
+剿
+劃
+劄
+劇
+劈
+劉
+劊
+劌
+劍
+劑
+劔
+力
+功
+加
+劣
+助
+努
+劫
+劬
+劭
+劵
+効
+劼
+劾
+勁
+勃
+勅
+勇
+勉
+勐
+勑
+勒
+勔
+動
+勖
+勗
+勘
+務
+勛
+勝
+勞
+募
+勢
+勣
+勤
+勦
+勰
+勱
+勲
+勳
+勵
+勷
+勸
+勺
+勻
+勾
+勿
+匂
+匄
+包
+匆
+匈
+匋
+匍
+匏
+匐
+匕
+化
+北
+匙
+匚
+匝
+匠
+匡
+匣
+匪
+匯
+匱
+匸
+匹
+匾
+匿
+區
+十
+千
+卅
+升
+午
+卉
+半
+卋
+卍
+卐
+卑
+卒
+卓
+協
+南
+博
+卜
+卞
+卟
+占
+卡
+卣
+卦
+卧
+卩
+卬
+卮
+卯
+印
+危
+卲
+即
+卵
+卷
+卸
+卹
+卺
+卻
+卽
+卿
+厄
+厓
+厔
+厙
+厚
+厝
+原
+厥
+厭
+厰
+厲
+厴
+厶
+去
+參
+叄
+又
+叉
+及
+友
+反
+収
+叔
+叕
+取
+受
+叛
+叟
+叡
+叢
+口
+古
+句
+另
+叨
+叩
+只
+叫
+召
+叭
+叮
+可
+台
+叱
+史
+右
+叵
+司
+叻
+叼
+吁
+吃
+各
+吆
+合
+吉
+吊
+吋
+同
+名
+后
+吏
+吐
+向
+吒
+吔
+吖
+君
+吝
+吞
+吟
+吠
+吡
+吥
+否
+吧
+吩
+含
+吮
+吱
+吲
+吳
+吵
+吶
+吸
+吹
+吻
+吼
+吾
+呀
+呂
+呃
+呈
+呉
+告
+呋
+呎
+呢
+呤
+呦
+周
+呱
+味
+呵
+呷
+呸
+呼
+命
+呾
+咀
+咁
+咂
+咄
+咅
+咆
+咋
+和
+咎
+咑
+咒
+咔
+咕
+咖
+咗
+咘
+咚
+咟
+咤
+咥
+咧
+咨
+咩
+咪
+咫
+咬
+咭
+咯
+咱
+咲
+咳
+咸
+咻
+咼
+咽
+咾
+咿
+哀
+品
+哂
+哄
+哆
+哇
+哈
+哉
+哌
+哎
+哏
+哐
+哖
+哚
+哞
+員
+哥
+哦
+哨
+哩
+哪
+哭
+哮
+哱
+哲
+哺
+哼
+唃
+唄
+唆
+唇
+唉
+唏
+唐
+唑
+唔
+唘
+唧
+唫
+唬
+唭
+售
+唯
+唱
+唳
+唵
+唷
+唸
+唻
+唾
+啁
+啃
+啄
+商
+啉
+啊
+啍
+問
+啓
+啖
+啚
+啜
+啞
+啟
+啡
+啣
+啤
+啥
+啦
+啪
+啫
+啯
+啰
+啱
+啲
+啵
+啶
+啷
+啻
+啼
+啾
+喀
+喂
+喃
+善
+喆
+喇
+喈
+喉
+喊
+喋
+喏
+喔
+喘
+喙
+喚
+喜
+喝
+喢
+喦
+喧
+喪
+喫
+喬
+單
+喰
+喱
+喲
+喳
+喵
+喹
+喻
+喼
+嗄
+嗅
+嗆
+嗇
+嗊
+嗎
+嗑
+嗒
+嗓
+嗔
+嗖
+嗚
+嗜
+嗝
+嗞
+嗡
+嗢
+嗣
+嗦
+嗨
+嗩
+嗪
+嗮
+嗯
+嗲
+嗶
+嗹
+嗽
+嘀
+嘅
+嘆
+嘉
+嘌
+嘍
+嘎
+嘏
+嘔
+嘗
+嘚
+嘛
+嘜
+嘞
+嘟
+嘢
+嘣
+嘥
+嘧
+嘩
+嘬
+嘮
+嘯
+嘰
+嘲
+嘴
+嘶
+嘸
+嘹
+嘻
+嘿
+噁
+噌
+噍
+噏
+噓
+噗
+噝
+噠
+噢
+噤
+噥
+噦
+器
+噩
+噪
+噬
+噯
+噰
+噲
+噴
+噶
+噸
+噹
+噻
+嚇
+嚈
+嚎
+嚏
+嚐
+嚒
+嚓
+嚕
+嚗
+嚙
+嚞
+嚟
+嚤
+嚦
+嚧
+嚨
+嚩
+嚮
+嚳
+嚴
+嚶
+嚷
+嚼
+嚿
+囀
+囂
+囃
+囉
+囊
+囍
+囑
+囒
+囓
+囗
+囚
+四
+囝
+回
+因
+囡
+団
+囤
+囧
+囪
+囮
+囯
+困
+囲
+図
+囶
+囷
+囹
+固
+囿
+圂
+圃
+圄
+圈
+圉
+國
+圍
+圏
+園
+圓
+圖
+圗
+團
+圜
+土
+圧
+在
+圩
+圪
+圭
+圯
+地
+圳
+圻
+圾
+址
+均
+坊
+坋
+坌
+坍
+坎
+坐
+坑
+坖
+坡
+坣
+坤
+坦
+坨
+坩
+坪
+坫
+坬
+坭
+坮
+坯
+坳
+坵
+坶
+坷
+坻
+垂
+垃
+垈
+型
+垍
+垓
+垕
+垚
+垛
+垞
+垟
+垠
+垢
+垣
+垮
+垯
+垰
+垵
+垸
+垻
+垿
+埃
+埅
+埇
+埈
+埋
+埌
+城
+埏
+埒
+埔
+埕
+埗
+埜
+域
+埠
+埡
+埤
+埧
+埨
+埪
+埭
+埮
+埴
+埵
+執
+培
+基
+埻
+埼
+堀
+堂
+堃
+堅
+堆
+堇
+堈
+堉
+堊
+堍
+堖
+堝
+堡
+堤
+堦
+堪
+堮
+堯
+堰
+報
+場
+堵
+堷
+堺
+塀
+塅
+塆
+塊
+塋
+塌
+塍
+塏
+塑
+塔
+塗
+塘
+塙
+塜
+塞
+塡
+塢
+塤
+塨
+塩
+填
+塬
+塭
+塰
+塱
+塲
+塵
+塹
+塽
+塾
+墀
+境
+墅
+墉
+墊
+墎
+墓
+増
+墘
+墜
+增
+墟
+墡
+墣
+墨
+墩
+墫
+墬
+墮
+墱
+墳
+墺
+墼
+墾
+壁
+壄
+壆
+壇
+壋
+壌
+壎
+壐
+壑
+壓
+壔
+壕
+壘
+壙
+壞
+壟
+壠
+壢
+壤
+壩
+士
+壬
+壯
+壱
+壴
+壹
+壺
+壽
+夀
+夆
+変
+夊
+夋
+夌
+夏
+夔
+夕
+外
+夙
+多
+夜
+夠
+夢
+夤
+夥
+大
+天
+太
+夫
+夬
+夭
+央
+夯
+失
+夷
+夾
+奀
+奄
+奇
+奈
+奉
+奎
+奏
+奐
+契
+奓
+奔
+奕
+套
+奘
+奚
+奠
+奢
+奣
+奧
+奩
+奪
+奫
+奭
+奮
+女
+奴
+奶
+她
+好
+妀
+妁
+如
+妃
+妄
+妊
+妍
+妏
+妑
+妒
+妓
+妖
+妙
+妝
+妞
+妠
+妤
+妥
+妧
+妨
+妭
+妮
+妯
+妲
+妳
+妸
+妹
+妺
+妻
+妾
+姀
+姁
+姃
+姆
+姈
+姉
+姊
+始
+姌
+姍
+姐
+姑
+姒
+姓
+委
+姚
+姜
+姝
+姣
+姥
+姦
+姨
+姪
+姫
+姬
+姮
+姵
+姶
+姸
+姻
+姿
+威
+娃
+娉
+娋
+娌
+娍
+娎
+娑
+娖
+娘
+娛
+娜
+娟
+娠
+娣
+娥
+娩
+娫
+娳
+娶
+娸
+娼
+娽
+婀
+婁
+婆
+婉
+婊
+婑
+婕
+婚
+婢
+婦
+婧
+婪
+婭
+婯
+婷
+婺
+婻
+婼
+婿
+媃
+媄
+媊
+媐
+媒
+媓
+媖
+媗
+媚
+媛
+媜
+媞
+媧
+媭
+媯
+媲
+媳
+媺
+媼
+媽
+媾
+媿
+嫁
+嫂
+嫄
+嫈
+嫉
+嫌
+嫖
+嫘
+嫚
+嫡
+嫣
+嫦
+嫩
+嫪
+嫲
+嫳
+嫵
+嫺
+嫻
+嬅
+嬈
+嬉
+嬋
+嬌
+嬗
+嬛
+嬝
+嬡
+嬤
+嬨
+嬪
+嬬
+嬭
+嬰
+嬴
+嬸
+嬾
+嬿
+孀
+孃
+孆
+孋
+孌
+子
+孑
+孔
+孕
+孖
+字
+存
+孚
+孛
+孜
+孝
+孟
+孢
+季
+孤
+孩
+孫
+孬
+孮
+孰
+孳
+孵
+學
+孺
+孻
+孽
+孿
+宀
+它
+宅
+宇
+守
+安
+宋
+完
+宍
+宏
+宓
+宕
+宗
+官
+宙
+定
+宛
+宜
+実
+客
+宣
+室
+宥
+宦
+宧
+宮
+宰
+害
+宴
+宵
+家
+宸
+容
+宿
+寀
+寁
+寂
+寄
+寅
+密
+寇
+寈
+寊
+富
+寐
+寒
+寓
+寔
+寕
+寖
+寗
+寘
+寛
+寜
+寞
+察
+寡
+寢
+寤
+寥
+實
+寧
+寨
+審
+寫
+寬
+寮
+寯
+寰
+寳
+寵
+寶
+寸
+寺
+対
+封
+専
+尃
+射
+將
+專
+尉
+尊
+尋
+對
+導
+小
+尐
+少
+尓
+尕
+尖
+尗
+尙
+尚
+尢
+尤
+尨
+尪
+尬
+就
+尷
+尹
+尺
+尻
+尼
+尾
+尿
+局
+屁
+屄
+居
+屆
+屇
+屈
+屋
+屌
+屍
+屎
+屏
+屐
+屑
+屓
+展
+屚
+屜
+屠
+屢
+層
+履
+屬
+屭
+屯
+山
+屹
+屺
+屻
+岀
+岈
+岌
+岐
+岑
+岔
+岡
+岢
+岣
+岧
+岩
+岪
+岫
+岬
+岰
+岱
+岳
+岵
+岷
+岸
+岻
+峁
+峅
+峇
+峋
+峍
+峒
+峘
+峙
+峚
+峠
+峨
+峩
+峪
+峭
+峯
+峰
+峴
+島
+峻
+峼
+峽
+崁
+崆
+崇
+崈
+崋
+崍
+崎
+崐
+崑
+崒
+崔
+崖
+崗
+崘
+崙
+崚
+崛
+崞
+崟
+崠
+崢
+崤
+崧
+崩
+崬
+崮
+崱
+崴
+崵
+崶
+崽
+嵇
+嵊
+嵋
+嵌
+嵎
+嵐
+嵒
+嵕
+嵖
+嵗
+嵙
+嵛
+嵜
+嵨
+嵩
+嵬
+嵮
+嵯
+嵰
+嵴
+嵻
+嵿
+嶁
+嶂
+嶃
+嶄
+嶇
+嶋
+嶌
+嶍
+嶒
+嶔
+嶗
+嶝
+嶠
+嶢
+嶦
+嶧
+嶪
+嶬
+嶰
+嶲
+嶴
+嶷
+嶸
+嶺
+嶼
+嶽
+巂
+巄
+巆
+巋
+巌
+巍
+巎
+巑
+巒
+巔
+巖
+巘
+巛
+川
+州
+巡
+巢
+工
+左
+巧
+巨
+巫
+差
+巰
+己
+已
+巳
+巴
+巶
+巷
+巻
+巽
+巾
+巿
+市
+布
+帆
+希
+帑
+帔
+帕
+帖
+帘
+帙
+帚
+帛
+帝
+帡
+帢
+帥
+師
+席
+帯
+帰
+帳
+帶
+帷
+常
+帽
+幀
+幃
+幄
+幅
+幌
+幔
+幕
+幗
+幚
+幛
+幟
+幡
+幢
+幣
+幪
+幫
+干
+平
+年
+幵
+幷
+幸
+幹
+幺
+幻
+幼
+幽
+幾
+庀
+庁
+広
+庇
+床
+序
+底
+庖
+店
+庚
+府
+庠
+庢
+庥
+度
+座
+庫
+庭
+庲
+庵
+庶
+康
+庸
+庹
+庼
+庾
+廁
+廂
+廄
+廆
+廈
+廉
+廊
+廋
+廌
+廍
+廑
+廓
+廔
+廕
+廖
+廙
+廚
+廝
+廞
+廟
+廠
+廡
+廢
+廣
+廧
+廨
+廩
+廬
+廰
+廱
+廳
+延
+廷
+廸
+建
+廻
+廼
+廿
+弁
+弄
+弅
+弇
+弈
+弉
+弊
+弋
+弍
+式
+弐
+弒
+弓
+弔
+引
+弖
+弗
+弘
+弛
+弟
+弢
+弦
+弧
+弨
+弩
+弭
+弱
+張
+強
+弸
+弼
+弾
+彀
+彄
+彅
+彆
+彈
+彊
+彌
+彎
+彐
+彔
+彖
+彗
+彘
+彙
+彜
+彞
+彠
+彡
+形
+彣
+彤
+彥
+彧
+彩
+彪
+彫
+彬
+彭
+彰
+影
+彳
+彷
+役
+彼
+彿
+往
+征
+徂
+待
+徇
+很
+徉
+徊
+律
+後
+徐
+徑
+徒
+得
+徘
+徙
+徜
+從
+徠
+御
+徧
+徨
+復
+循
+徫
+徬
+徭
+微
+徳
+徴
+徵
+德
+徸
+徹
+徽
+心
+忄
+必
+忉
+忌
+忍
+忐
+忑
+忒
+志
+忘
+忙
+応
+忝
+忞
+忠
+快
+忬
+忯
+忱
+忳
+念
+忻
+忽
+忿
+怍
+怎
+怒
+怕
+怖
+怙
+怛
+思
+怠
+怡
+急
+怦
+性
+怨
+怪
+怯
+怵
+恁
+恂
+恃
+恆
+恊
+恍
+恐
+恕
+恙
+恢
+恣
+恤
+恥
+恨
+恩
+恪
+恬
+恭
+息
+恰
+恵
+恿
+悄
+悅
+悆
+悉
+悌
+悍
+悔
+悖
+悚
+悛
+悝
+悞
+悟
+悠
+患
+悧
+您
+悪
+悰
+悲
+悳
+悵
+悶
+悸
+悼
+情
+惆
+惇
+惑
+惔
+惕
+惘
+惚
+惜
+惟
+惠
+惡
+惣
+惦
+惰
+惱
+惲
+想
+惶
+惹
+惺
+愁
+愃
+愆
+愈
+愉
+愍
+意
+愐
+愒
+愔
+愕
+愚
+愛
+愜
+感
+愣
+愧
+愨
+愫
+愭
+愴
+愷
+愼
+愾
+愿
+慄
+慈
+態
+慌
+慎
+慕
+慘
+慚
+慜
+慟
+慢
+慣
+慥
+慧
+慨
+慮
+慰
+慳
+慵
+慶
+慷
+慾
+憂
+憊
+憋
+憍
+憎
+憐
+憑
+憓
+憕
+憙
+憚
+憤
+憧
+憨
+憩
+憫
+憬
+憲
+憶
+憺
+憻
+憾
+懂
+懃
+懇
+懈
+應
+懋
+懌
+懍
+懐
+懣
+懦
+懮
+懲
+懵
+懶
+懷
+懸
+懺
+懼
+懽
+懾
+懿
+戀
+戇
+戈
+戊
+戌
+戍
+戎
+成
+我
+戒
+戔
+戕
+或
+戙
+戚
+戛
+戟
+戡
+戢
+戥
+戦
+戩
+截
+戮
+戰
+戱
+戲
+戳
+戴
+戶
+戸
+戻
+戽
+戾
+房
+所
+扁
+扆
+扇
+扈
+扉
+手
+扌
+才
+扎
+扒
+打
+扔
+托
+扙
+扛
+扞
+扣
+扥
+扦
+扭
+扮
+扯
+扳
+扶
+批
+扼
+找
+承
+技
+抃
+抄
+抇
+抉
+把
+抑
+抒
+抓
+投
+抖
+抗
+折
+抦
+披
+抬
+抱
+抵
+抹
+抻
+押
+抽
+抿
+拂
+拆
+拇
+拈
+拉
+拋
+拌
+拍
+拎
+拏
+拐
+拒
+拓
+拔
+拖
+拗
+拘
+拙
+拚
+招
+拜
+拝
+拡
+括
+拭
+拮
+拯
+拱
+拳
+拴
+拷
+拺
+拼
+拽
+拾
+拿
+持
+指
+按
+挎
+挑
+挖
+挙
+挨
+挪
+挫
+振
+挲
+挵
+挹
+挺
+挻
+挾
+捂
+捆
+捉
+捌
+捍
+捎
+捏
+捐
+捒
+捕
+捜
+捦
+捧
+捨
+捩
+捫
+捭
+捱
+捲
+捶
+捷
+捺
+捻
+掀
+掂
+掃
+掄
+掇
+授
+掉
+掌
+掏
+掐
+排
+掖
+掘
+掙
+掛
+掞
+掟
+掠
+採
+探
+掣
+接
+控
+推
+掩
+措
+掬
+掰
+掾
+揀
+揄
+揆
+揉
+揍
+描
+提
+插
+揔
+揖
+揚
+換
+握
+揪
+揭
+揮
+援
+揸
+揺
+損
+搏
+搐
+搓
+搔
+搖
+搗
+搜
+搞
+搠
+搢
+搪
+搬
+搭
+搳
+搴
+搵
+搶
+搽
+搾
+摂
+摒
+摔
+摘
+摜
+摞
+摟
+摠
+摧
+摩
+摭
+摯
+摳
+摴
+摵
+摶
+摸
+摹
+摺
+摻
+摽
+撃
+撇
+撈
+撐
+撒
+撓
+撕
+撖
+撙
+撚
+撞
+撣
+撤
+撥
+撩
+撫
+撬
+播
+撮
+撰
+撲
+撳
+撻
+撼
+撾
+撿
+擀
+擁
+擂
+擅
+擇
+擊
+擋
+操
+擎
+擒
+擔
+擘
+據
+擠
+擢
+擥
+擦
+擬
+擯
+擰
+擱
+擲
+擴
+擷
+擺
+擼
+擾
+攀
+攏
+攔
+攖
+攘
+攜
+攝
+攞
+攢
+攣
+攤
+攪
+攫
+攬
+支
+攴
+攵
+收
+攷
+攸
+改
+攻
+攽
+放
+政
+故
+效
+敍
+敎
+敏
+救
+敔
+敕
+敖
+敗
+敘
+教
+敝
+敞
+敟
+敢
+散
+敦
+敫
+敬
+敭
+敲
+整
+敵
+敷
+數
+敻
+敾
+斂
+斃
+文
+斌
+斎
+斐
+斑
+斕
+斖
+斗
+料
+斛
+斜
+斝
+斟
+斡
+斤
+斥
+斧
+斬
+斯
+新
+斷
+方
+於
+施
+斿
+旁
+旂
+旃
+旄
+旅
+旉
+旋
+旌
+旎
+族
+旖
+旗
+旙
+旛
+旡
+既
+日
+旦
+旨
+早
+旬
+旭
+旱
+旲
+旳
+旺
+旻
+旼
+旽
+旾
+旿
+昀
+昂
+昃
+昆
+昇
+昉
+昊
+昌
+昍
+明
+昏
+昐
+易
+昔
+昕
+昚
+昛
+昜
+昝
+昞
+星
+映
+昡
+昣
+昤
+春
+昧
+昨
+昪
+昫
+昭
+是
+昰
+昱
+昴
+昵
+昶
+昺
+晁
+時
+晃
+晈
+晉
+晊
+晏
+晗
+晙
+晚
+晛
+晝
+晞
+晟
+晤
+晦
+晧
+晨
+晩
+晪
+晫
+晭
+普
+景
+晰
+晳
+晴
+晶
+晷
+晸
+智
+晾
+暃
+暄
+暅
+暇
+暈
+暉
+暊
+暌
+暎
+暏
+暐
+暑
+暕
+暖
+暗
+暘
+暝
+暟
+暠
+暢
+暦
+暨
+暫
+暮
+暱
+暲
+暴
+暸
+暹
+暻
+暾
+曄
+曅
+曆
+曇
+曉
+曌
+曔
+曖
+曙
+曜
+曝
+曠
+曦
+曧
+曨
+曩
+曬
+曮
+曰
+曲
+曳
+更
+曶
+曷
+書
+曹
+曺
+曼
+曽
+曾
+替
+最
+會
+月
+有
+朊
+朋
+服
+朏
+朐
+朓
+朔
+朕
+朖
+朗
+望
+朝
+期
+朦
+朧
+木
+未
+末
+本
+札
+朱
+朴
+朵
+朶
+朽
+朿
+杁
+杉
+杋
+杌
+李
+杏
+材
+村
+杓
+杖
+杙
+杜
+杞
+束
+杠
+杣
+杤
+杧
+杬
+杭
+杯
+東
+杲
+杳
+杴
+杵
+杷
+杻
+杼
+松
+板
+极
+枇
+枉
+枋
+枏
+析
+枕
+枖
+林
+枚
+枛
+果
+枝
+枠
+枡
+枯
+枰
+枱
+枲
+枳
+架
+枷
+枸
+枹
+枼
+柁
+柃
+柄
+柉
+柊
+柎
+柏
+某
+柑
+柒
+染
+柔
+柘
+柚
+柜
+柝
+柞
+柟
+查
+柩
+柬
+柯
+柰
+柱
+柳
+柴
+柵
+柶
+柷
+査
+柾
+柿
+栃
+栄
+栐
+栒
+栓
+栜
+栝
+栞
+校
+栢
+栨
+栩
+株
+栲
+栴
+核
+根
+栻
+格
+栽
+桀
+桁
+桂
+桃
+桄
+桅
+框
+案
+桉
+桌
+桎
+桐
+桑
+桓
+桔
+桕
+桖
+桙
+桜
+桝
+桫
+桱
+桲
+桴
+桶
+桷
+桼
+桿
+梀
+梁
+梂
+梃
+梅
+梆
+梉
+梏
+梓
+梔
+梗
+梘
+條
+梟
+梠
+梢
+梣
+梧
+梨
+梫
+梭
+梯
+械
+梱
+梳
+梵
+梶
+梽
+棄
+棆
+棉
+棋
+棍
+棐
+棒
+棓
+棕
+棖
+棗
+棘
+棚
+棛
+棟
+棠
+棡
+棣
+棧
+棨
+棩
+棪
+棫
+森
+棱
+棲
+棵
+棶
+棹
+棺
+棻
+棼
+棽
+椅
+椆
+椇
+椋
+植
+椎
+椏
+椒
+椙
+椥
+椪
+椰
+椲
+椴
+椵
+椹
+椽
+椿
+楂
+楊
+楓
+楔
+楗
+楙
+楚
+楝
+楞
+楠
+楡
+楢
+楣
+楤
+楦
+楧
+楨
+楫
+業
+楮
+楯
+楳
+極
+楷
+楸
+楹
+楽
+楿
+概
+榆
+榊
+榍
+榎
+榑
+榔
+榕
+榖
+榗
+榘
+榛
+榜
+榞
+榢
+榣
+榤
+榦
+榧
+榨
+榫
+榭
+榮
+榲
+榴
+榷
+榻
+榿
+槀
+槁
+槃
+槊
+構
+槌
+槍
+槎
+槐
+槓
+槔
+槗
+様
+槙
+槤
+槩
+槭
+槰
+槱
+槲
+槳
+槺
+槻
+槼
+槽
+槿
+樀
+樁
+樂
+樅
+樆
+樊
+樋
+樑
+樓
+樗
+樘
+標
+樞
+樟
+模
+樣
+樨
+権
+樫
+樵
+樸
+樹
+樺
+樻
+樽
+樾
+橄
+橇
+橈
+橋
+橐
+橒
+橓
+橘
+橙
+橚
+機
+橡
+橢
+橪
+橫
+橿
+檀
+檄
+檇
+檉
+檊
+檎
+檐
+檔
+檗
+檜
+檞
+檠
+檡
+檢
+檣
+檦
+檨
+檫
+檬
+檯
+檳
+檵
+檸
+檻
+檽
+櫂
+櫃
+櫆
+櫈
+櫓
+櫚
+櫛
+櫞
+櫟
+櫥
+櫨
+櫪
+櫱
+櫸
+櫻
+櫾
+櫿
+欄
+欉
+權
+欏
+欒
+欖
+欞
+欠
+次
+欣
+欥
+欲
+欸
+欹
+欺
+欽
+款
+歆
+歇
+歉
+歊
+歌
+歎
+歐
+歓
+歙
+歛
+歡
+止
+正
+此
+步
+武
+歧
+歩
+歪
+歲
+歳
+歴
+歷
+歸
+歹
+死
+歿
+殂
+殃
+殄
+殆
+殉
+殊
+殑
+殖
+殘
+殛
+殞
+殟
+殤
+殭
+殮
+殯
+殲
+殳
+段
+殷
+殺
+殻
+殼
+殿
+毀
+毅
+毆
+毉
+毋
+毌
+母
+毎
+每
+毐
+毒
+毓
+比
+毖
+毗
+毘
+毛
+毫
+毬
+毯
+毴
+毸
+毽
+毿
+氂
+氈
+氍
+氏
+氐
+民
+氓
+氖
+気
+氘
+氙
+氚
+氛
+氟
+氣
+氦
+氧
+氨
+氪
+氫
+氬
+氮
+氯
+氰
+水
+氵
+氷
+永
+氹
+氻
+氽
+氾
+汀
+汁
+求
+汊
+汎
+汐
+汕
+汗
+汛
+汜
+汝
+汞
+江
+池
+污
+汧
+汨
+汩
+汪
+汭
+汰
+汲
+汴
+汶
+決
+汽
+汾
+沁
+沂
+沃
+沄
+沅
+沆
+沇
+沈
+沉
+沌
+沍
+沏
+沐
+沒
+沓
+沔
+沖
+沘
+沙
+沚
+沛
+沜
+沢
+沨
+沫
+沭
+沮
+沯
+沱
+河
+沸
+油
+沺
+治
+沼
+沽
+沾
+沿
+況
+泂
+泄
+泆
+泇
+泉
+泊
+泌
+泐
+泓
+泔
+法
+泖
+泗
+泚
+泛
+泠
+泡
+波
+泣
+泥
+泩
+泫
+泮
+泯
+泰
+泱
+泳
+泵
+洄
+洋
+洌
+洎
+洗
+洙
+洛
+洞
+洢
+洣
+洤
+津
+洨
+洩
+洪
+洮
+洱
+洲
+洳
+洵
+洸
+洹
+洺
+活
+洽
+派
+流
+浄
+浙
+浚
+浛
+浜
+浞
+浟
+浠
+浡
+浣
+浤
+浥
+浦
+浩
+浪
+浮
+浯
+浴
+浵
+海
+浸
+浹
+涅
+涇
+消
+涉
+涌
+涎
+涑
+涓
+涔
+涕
+涙
+涪
+涫
+涮
+涯
+液
+涵
+涸
+涼
+涿
+淄
+淅
+淆
+淇
+淋
+淌
+淍
+淎
+淏
+淑
+淓
+淖
+淘
+淙
+淚
+淛
+淝
+淞
+淠
+淡
+淤
+淥
+淦
+淨
+淩
+淪
+淫
+淬
+淮
+淯
+淰
+深
+淳
+淵
+淶
+混
+淸
+淹
+淺
+添
+淼
+淽
+渃
+清
+済
+渉
+渋
+渕
+渙
+渚
+減
+渝
+渟
+渠
+渡
+渣
+渤
+渥
+渦
+渫
+測
+渭
+港
+渲
+渴
+游
+渺
+渼
+渽
+渾
+湃
+湄
+湉
+湊
+湍
+湓
+湔
+湖
+湘
+湛
+湜
+湞
+湟
+湣
+湥
+湧
+湫
+湮
+湯
+湳
+湴
+湼
+満
+溁
+溇
+溈
+溉
+溋
+溎
+溏
+源
+準
+溙
+溜
+溝
+溟
+溢
+溥
+溦
+溧
+溪
+溫
+溯
+溱
+溲
+溴
+溵
+溶
+溺
+溼
+滀
+滁
+滂
+滄
+滅
+滇
+滈
+滉
+滋
+滌
+滎
+滏
+滑
+滓
+滔
+滕
+滘
+滙
+滝
+滬
+滯
+滲
+滴
+滷
+滸
+滹
+滻
+滽
+滾
+滿
+漁
+漂
+漆
+漇
+漈
+漎
+漏
+漓
+演
+漕
+漚
+漠
+漢
+漣
+漩
+漪
+漫
+漬
+漯
+漱
+漲
+漳
+漴
+漵
+漷
+漸
+漼
+漾
+漿
+潁
+潑
+潔
+潘
+潛
+潞
+潟
+潢
+潤
+潭
+潮
+潯
+潰
+潲
+潺
+潼
+潽
+潾
+潿
+澀
+澁
+澂
+澄
+澆
+澇
+澈
+澉
+澋
+澌
+澍
+澎
+澔
+澗
+澠
+澡
+澣
+澤
+澥
+澧
+澪
+澮
+澯
+澱
+澳
+澶
+澹
+澻
+激
+濁
+濂
+濃
+濉
+濊
+濋
+濕
+濘
+濙
+濛
+濞
+濟
+濠
+濡
+濤
+濫
+濬
+濮
+濯
+濰
+濱
+濲
+濶
+濺
+濼
+濾
+瀁
+瀅
+瀆
+瀉
+瀍
+瀏
+瀑
+瀔
+瀕
+瀘
+瀚
+瀛
+瀝
+瀞
+瀟
+瀠
+瀣
+瀦
+瀧
+瀨
+瀬
+瀰
+瀲
+瀴
+瀶
+瀹
+瀾
+灃
+灊
+灌
+灑
+灘
+灝
+灞
+灡
+灣
+灤
+灧
+火
+灰
+灴
+灸
+灼
+災
+炁
+炅
+炆
+炊
+炎
+炒
+炔
+炕
+炘
+炙
+炟
+炣
+炤
+炫
+炬
+炭
+炮
+炯
+炱
+炲
+炳
+炷
+炸
+為
+炻
+烈
+烉
+烊
+烋
+烏
+烒
+烔
+烘
+烙
+烜
+烝
+烤
+烯
+烱
+烴
+烷
+烹
+烺
+烽
+焃
+焄
+焉
+焊
+焌
+焓
+焗
+焙
+焚
+焜
+焞
+無
+焦
+焯
+焰
+焱
+焴
+然
+焻
+焼
+焿
+煇
+煉
+煊
+煌
+煎
+煐
+煒
+煔
+煕
+煖
+煙
+煚
+煜
+煞
+煠
+煤
+煥
+煦
+照
+煨
+煩
+煬
+煮
+煲
+煳
+煵
+煶
+煸
+煽
+熄
+熅
+熇
+熈
+熊
+熏
+熒
+熔
+熖
+熗
+熘
+熙
+熜
+熟
+熠
+熤
+熥
+熨
+熬
+熯
+熱
+熲
+熳
+熵
+熹
+熺
+熼
+熾
+熿
+燁
+燃
+燄
+燈
+燉
+燊
+燎
+燏
+燐
+燒
+燔
+燕
+燘
+燙
+燚
+燜
+燝
+營
+燥
+燦
+燧
+燫
+燬
+燭
+燮
+燴
+燹
+燻
+燼
+燾
+燿
+爀
+爆
+爌
+爍
+爐
+爔
+爚
+爛
+爝
+爨
+爪
+爬
+爭
+爯
+爰
+爲
+爵
+父
+爸
+爹
+爺
+爻
+爽
+爾
+爿
+牁
+牂
+牆
+片
+版
+牌
+牒
+牕
+牖
+牘
+牙
+牛
+牝
+牟
+牠
+牡
+牢
+牧
+物
+牯
+牲
+特
+牻
+牼
+牽
+犀
+犁
+犂
+犇
+犍
+犎
+犖
+犛
+犢
+犧
+犨
+犬
+犯
+犰
+犴
+犽
+狀
+狂
+狄
+狍
+狎
+狐
+狒
+狓
+狗
+狙
+狛
+狟
+狠
+狡
+狦
+狨
+狩
+狳
+狶
+狷
+狸
+狹
+狻
+狼
+猁
+猄
+猇
+猊
+猗
+猙
+猛
+猜
+猝
+猞
+猢
+猥
+猨
+猩
+猳
+猴
+猶
+猷
+猺
+猻
+猾
+猿
+獁
+獃
+獄
+獅
+獇
+獎
+獏
+獐
+獒
+獠
+獢
+獣
+獨
+獬
+獮
+獯
+獰
+獲
+獴
+獵
+獷
+獸
+獺
+獻
+獼
+獾
+玀
+玄
+玆
+率
+玉
+王
+玎
+玏
+玓
+玕
+玖
+玗
+玘
+玙
+玟
+玠
+玡
+玢
+玥
+玧
+玨
+玩
+玫
+玭
+玲
+玳
+玶
+玷
+玹
+玻
+玾
+珀
+珂
+珅
+珈
+珉
+珊
+珌
+珍
+珎
+珏
+珖
+珙
+珝
+珞
+珠
+珡
+珣
+珤
+珥
+珦
+珧
+珩
+珪
+班
+珮
+珵
+珹
+珺
+珽
+現
+琁
+球
+琄
+琅
+理
+琇
+琉
+琊
+琍
+琎
+琚
+琛
+琡
+琢
+琤
+琥
+琦
+琨
+琪
+琬
+琮
+琯
+琰
+琱
+琳
+琴
+琵
+琶
+琹
+琺
+琿
+瑀
+瑁
+瑂
+瑄
+瑅
+瑆
+瑈
+瑊
+瑋
+瑑
+瑒
+瑕
+瑗
+瑙
+瑚
+瑛
+瑜
+瑝
+瑞
+瑟
+瑠
+瑢
+瑣
+瑤
+瑥
+瑧
+瑨
+瑩
+瑪
+瑭
+瑯
+瑰
+瑱
+瑳
+瑴
+瑺
+瑾
+璀
+璁
+璃
+璄
+璆
+璇
+璈
+璉
+璋
+璌
+璐
+璕
+璘
+璙
+璚
+璜
+璞
+璟
+璠
+璡
+璣
+璥
+璦
+璧
+璨
+璩
+璪
+璫
+璬
+璮
+環
+璱
+璵
+璸
+璹
+璽
+璿
+瓈
+瓊
+瓌
+瓏
+瓑
+瓔
+瓖
+瓘
+瓚
+瓛
+瓜
+瓞
+瓠
+瓢
+瓣
+瓤
+瓦
+瓮
+瓴
+瓶
+瓷
+瓿
+甂
+甄
+甌
+甍
+甑
+甕
+甘
+甙
+甚
+甜
+生
+甡
+產
+産
+甥
+甦
+用
+甩
+甪
+甫
+甬
+甯
+田
+由
+甲
+申
+男
+甸
+甹
+町
+甾
+畀
+畇
+畈
+畊
+畋
+界
+畎
+畏
+畐
+畑
+畔
+留
+畜
+畝
+畠
+畢
+略
+畦
+畧
+番
+畫
+畬
+畯
+異
+畲
+畳
+畵
+當
+畷
+畸
+畹
+畿
+疃
+疆
+疇
+疊
+疋
+疌
+疍
+疏
+疑
+疒
+疕
+疙
+疚
+疝
+疣
+疤
+疥
+疫
+疲
+疳
+疵
+疸
+疹
+疼
+疽
+疾
+痂
+病
+症
+痊
+痍
+痔
+痕
+痘
+痙
+痛
+痞
+痟
+痠
+痢
+痣
+痤
+痧
+痩
+痰
+痱
+痲
+痴
+痹
+痺
+痿
+瘀
+瘁
+瘊
+瘋
+瘍
+瘓
+瘙
+瘜
+瘞
+瘟
+瘠
+瘡
+瘢
+瘤
+瘦
+瘧
+瘩
+瘰
+瘴
+瘺
+癀
+療
+癆
+癇
+癌
+癒
+癖
+癘
+癜
+癟
+癡
+癢
+癤
+癥
+癩
+癬
+癭
+癮
+癯
+癰
+癱
+癲
+癸
+発
+登
+發
+白
+百
+皂
+的
+皆
+皇
+皈
+皋
+皎
+皐
+皓
+皖
+皙
+皚
+皛
+皝
+皞
+皮
+皰
+皴
+皷
+皸
+皺
+皿
+盂
+盃
+盅
+盆
+盈
+益
+盋
+盌
+盎
+盒
+盔
+盛
+盜
+盞
+盟
+盡
+監
+盤
+盥
+盦
+盧
+盨
+盩
+盪
+盫
+目
+盯
+盱
+盲
+直
+盷
+相
+盹
+盺
+盼
+盾
+眀
+省
+眉
+看
+県
+眙
+眛
+眜
+眞
+真
+眠
+眥
+眨
+眩
+眭
+眯
+眵
+眶
+眷
+眸
+眺
+眼
+眾
+着
+睇
+睛
+睜
+睞
+睡
+睢
+督
+睥
+睦
+睨
+睪
+睫
+睭
+睹
+睺
+睽
+睾
+睿
+瞄
+瞅
+瞋
+瞌
+瞎
+瞑
+瞓
+瞞
+瞢
+瞥
+瞧
+瞪
+瞫
+瞬
+瞭
+瞰
+瞳
+瞻
+瞼
+瞽
+瞿
+矇
+矍
+矗
+矚
+矛
+矜
+矞
+矢
+矣
+知
+矧
+矩
+短
+矮
+矯
+石
+矸
+矽
+砂
+砋
+砌
+砍
+砒
+研
+砝
+砢
+砥
+砦
+砧
+砩
+砫
+砭
+砮
+砯
+砰
+砲
+砳
+破
+砵
+砷
+砸
+砼
+硂
+硃
+硅
+硇
+硏
+硐
+硒
+硓
+硚
+硜
+硝
+硤
+硨
+硫
+硬
+硭
+硯
+硼
+碁
+碇
+碉
+碌
+碎
+碑
+碓
+碕
+碗
+碘
+碚
+碟
+碡
+碣
+碧
+碩
+碪
+碭
+碰
+碲
+碳
+碴
+碶
+碸
+確
+碻
+碼
+碽
+碾
+磁
+磅
+磊
+磋
+磐
+磔
+磕
+磘
+磙
+磚
+磜
+磡
+磨
+磪
+磬
+磯
+磱
+磲
+磵
+磷
+磺
+磻
+磾
+礁
+礄
+礎
+礐
+礑
+礒
+礙
+礠
+礦
+礪
+礫
+礬
+礮
+礱
+礴
+示
+礻
+礽
+社
+祀
+祁
+祂
+祆
+祇
+祈
+祉
+祋
+祏
+祐
+祓
+祕
+祖
+祗
+祙
+祚
+祛
+祜
+祝
+神
+祟
+祠
+祥
+祧
+票
+祭
+祹
+祺
+祼
+祿
+禁
+禃
+禇
+禍
+禎
+福
+禑
+禓
+禔
+禕
+禘
+禛
+禟
+禠
+禤
+禦
+禧
+禨
+禩
+禪
+禮
+禰
+禱
+禵
+禹
+禺
+禼
+禽
+禾
+禿
+秀
+私
+秈
+秉
+秋
+科
+秒
+秕
+秘
+租
+秠
+秣
+秤
+秦
+秧
+秩
+秭
+秳
+秸
+移
+稀
+稅
+稈
+稉
+程
+稍
+稑
+稔
+稗
+稘
+稙
+稚
+稜
+稞
+稟
+稠
+種
+稱
+稲
+稷
+稹
+稺
+稻
+稼
+稽
+稾
+稿
+穀
+穂
+穆
+穈
+穉
+穌
+積
+穎
+穗
+穟
+穠
+穡
+穢
+穣
+穩
+穫
+穰
+穴
+穵
+究
+穹
+空
+穿
+突
+窄
+窅
+窈
+窋
+窒
+窕
+窖
+窗
+窘
+窟
+窠
+窣
+窨
+窩
+窪
+窮
+窯
+窰
+窶
+窺
+窿
+竄
+竅
+竇
+竈
+竊
+立
+竑
+站
+竜
+竟
+章
+竣
+童
+竦
+竩
+竭
+端
+競
+竹
+竺
+竻
+竿
+笄
+笆
+笈
+笏
+笑
+笘
+笙
+笛
+笞
+笠
+笥
+符
+笨
+笩
+笪
+第
+笭
+笮
+笯
+笱
+笳
+笹
+筅
+筆
+等
+筊
+筋
+筌
+筍
+筏
+筐
+筒
+答
+策
+筘
+筠
+筥
+筦
+筧
+筬
+筭
+筱
+筲
+筳
+筵
+筶
+筷
+筻
+箆
+箇
+箋
+箍
+箏
+箐
+箑
+箒
+箔
+箕
+算
+箜
+管
+箬
+箭
+箱
+箴
+箸
+節
+篁
+範
+篆
+篇
+築
+篊
+篋
+篌
+篔
+篙
+篝
+篠
+篡
+篤
+篥
+篦
+篩
+篪
+篭
+篯
+篳
+篷
+簀
+簃
+簇
+簉
+簋
+簍
+簑
+簕
+簗
+簞
+簠
+簡
+簧
+簪
+簫
+簷
+簸
+簹
+簺
+簽
+簾
+簿
+籀
+籃
+籌
+籍
+籐
+籙
+籛
+籜
+籝
+籟
+籠
+籣
+籤
+籥
+籪
+籬
+籮
+籲
+米
+籽
+籾
+粄
+粉
+粍
+粑
+粒
+粕
+粗
+粘
+粟
+粢
+粥
+粦
+粧
+粩
+粱
+粲
+粳
+粵
+粹
+粼
+粽
+精
+粿
+糀
+糅
+糊
+糌
+糍
+糎
+糕
+糖
+糙
+糜
+糝
+糞
+糟
+糠
+糢
+糧
+糬
+糯
+糰
+糴
+糶
+糸
+糹
+糺
+系
+糾
+紀
+紂
+約
+紅
+紆
+紇
+紈
+紉
+紊
+紋
+納
+紐
+紑
+紓
+純
+紕
+紗
+紘
+紙
+級
+紛
+紜
+紝
+紞
+素
+紡
+索
+紫
+紮
+累
+細
+紱
+紲
+紳
+紵
+紹
+紺
+紿
+終
+絃
+組
+絆
+経
+絎
+結
+絕
+絛
+絜
+絞
+絡
+絢
+給
+絨
+絪
+絮
+統
+絲
+絳
+絵
+絶
+絹
+絺
+綁
+綃
+綈
+綉
+綎
+綏
+經
+綖
+継
+続
+綜
+綝
+綞
+綠
+綢
+綣
+綦
+綧
+綫
+綬
+維
+綮
+綰
+綱
+網
+綳
+綴
+綸
+綺
+綻
+綽
+綾
+綿
+緁
+緃
+緄
+緈
+緊
+緋
+総
+緑
+緒
+緖
+緘
+線
+緜
+緝
+緞
+締
+緡
+緣
+緤
+編
+緩
+緬
+緯
+緱
+緲
+練
+緹
+緻
+縂
+縄
+縈
+縉
+縊
+縕
+縛
+縝
+縞
+縠
+縡
+縣
+縤
+縫
+縮
+縯
+縱
+縴
+縵
+縷
+縹
+縻
+總
+績
+繁
+繃
+繆
+繇
+繒
+織
+繕
+繖
+繙
+繚
+繞
+繡
+繩
+繪
+繫
+繭
+繰
+繳
+繹
+繻
+繼
+繽
+繾
+纁
+纂
+纈
+續
+纍
+纏
+纓
+纔
+纕
+纖
+纘
+纛
+纜
+缐
+缶
+缸
+缺
+缽
+罃
+罄
+罅
+罈
+罉
+罌
+罍
+罐
+罔
+罕
+罘
+罟
+罡
+罨
+罩
+罪
+置
+罰
+罱
+署
+罳
+罵
+罶
+罷
+罹
+罽
+羂
+羅
+羆
+羈
+羊
+羋
+羌
+美
+羔
+羕
+羗
+羙
+羚
+羞
+羡
+羣
+群
+羥
+羧
+羨
+義
+羯
+羰
+羱
+羲
+羸
+羹
+羽
+羿
+翀
+翁
+翂
+翃
+翅
+翊
+翌
+翎
+翏
+習
+翔
+翕
+翙
+翜
+翟
+翠
+翡
+翥
+翦
+翩
+翬
+翮
+翰
+翱
+翳
+翹
+翻
+翼
+耀
+老
+考
+耄
+者
+耆
+而
+耍
+耎
+耐
+耑
+耒
+耔
+耕
+耗
+耘
+耙
+耜
+耦
+耨
+耬
+耳
+耵
+耶
+耷
+耽
+耿
+聃
+聆
+聊
+聒
+聖
+聘
+聚
+聞
+聟
+聨
+聯
+聰
+聱
+聲
+聳
+聴
+聶
+職
+聽
+聾
+聿
+肄
+肅
+肆
+肇
+肉
+肋
+肌
+肏
+肖
+肘
+肚
+肛
+肜
+肝
+肟
+股
+肢
+肥
+肩
+肪
+肫
+肯
+肱
+育
+肸
+肹
+肺
+肼
+肽
+胂
+胃
+胄
+胅
+胇
+胊
+背
+胍
+胎
+胖
+胗
+胙
+胚
+胛
+胝
+胞
+胡
+胤
+胥
+胬
+胭
+胰
+胱
+胳
+胴
+胸
+胺
+胼
+能
+脂
+脅
+脆
+脇
+脈
+脊
+脒
+脖
+脘
+脛
+脣
+脩
+脫
+脬
+脭
+脯
+脲
+脳
+脷
+脹
+脾
+腆
+腈
+腊
+腋
+腌
+腎
+腐
+腑
+腓
+腔
+腕
+腥
+腦
+腧
+腩
+腫
+腮
+腰
+腱
+腳
+腴
+腸
+腹
+腺
+腿
+膀
+膂
+膈
+膊
+膏
+膚
+膛
+膜
+膝
+膠
+膣
+膥
+膦
+膨
+膩
+膮
+膳
+膺
+膽
+膾
+膿
+臀
+臂
+臃
+臆
+臉
+臊
+臍
+臏
+臘
+臚
+臞
+臟
+臠
+臣
+臧
+臨
+自
+臭
+臯
+至
+致
+臺
+臻
+臼
+臾
+舂
+舅
+與
+興
+舉
+舊
+舌
+舍
+舎
+舒
+舔
+舖
+舘
+舛
+舜
+舞
+舟
+舢
+舥
+舨
+舩
+航
+舫
+般
+舲
+舵
+舶
+舷
+舸
+船
+舺
+艅
+艇
+艉
+艋
+艎
+艏
+艔
+艘
+艙
+艚
+艦
+艮
+良
+艱
+色
+艶
+艷
+艸
+艽
+艾
+艿
+芃
+芊
+芋
+芍
+芎
+芑
+芒
+芘
+芙
+芛
+芝
+芡
+芥
+芨
+芩
+芪
+芫
+芬
+芭
+芮
+芯
+花
+芳
+芴
+芷
+芸
+芹
+芻
+芽
+芾
+苄
+苅
+苑
+苒
+苓
+苔
+苕
+苗
+苛
+苜
+苝
+苞
+苟
+苡
+苣
+苤
+若
+苦
+苧
+苪
+苫
+苯
+英
+苳
+苴
+苷
+苺
+苻
+苼
+苾
+茀
+茁
+茂
+范
+茄
+茅
+茆
+茇
+茈
+茉
+茌
+茗
+茘
+茚
+茛
+茜
+茝
+茨
+茫
+茬
+茭
+茮
+茯
+茱
+茲
+茴
+茵
+茶
+茷
+茸
+茹
+茺
+茼
+荀
+荃
+荅
+荇
+草
+荊
+荎
+荏
+荒
+荔
+荖
+荘
+荳
+荷
+荸
+荻
+荼
+荽
+莆
+莉
+莊
+莎
+莒
+莓
+莕
+莖
+莘
+莙
+莛
+莜
+莞
+莠
+莢
+莧
+莨
+莩
+莪
+莫
+莽
+莿
+菀
+菁
+菅
+菇
+菈
+菉
+菊
+菌
+菍
+菏
+菑
+菓
+菔
+菖
+菘
+菜
+菝
+菟
+菠
+菡
+菥
+菩
+菪
+菫
+華
+菰
+菱
+菲
+菴
+菶
+菸
+菹
+菺
+菼
+菽
+菾
+萁
+萃
+萄
+萇
+萊
+萌
+萍
+萎
+萐
+萘
+萜
+萠
+萡
+萣
+萩
+萬
+萭
+萱
+萵
+萸
+萹
+萼
+落
+葃
+葆
+葉
+葊
+葎
+葑
+葒
+著
+葙
+葚
+葛
+葜
+葝
+葡
+董
+葦
+葩
+葫
+葬
+葭
+葯
+葰
+葳
+葵
+葶
+葷
+葺
+蒂
+蒄
+蒍
+蒎
+蒐
+蒓
+蒔
+蒗
+蒙
+蒜
+蒞
+蒟
+蒡
+蒢
+蒤
+蒧
+蒨
+蒭
+蒯
+蒲
+蒴
+蒸
+蒹
+蒺
+蒻
+蒼
+蒽
+蒾
+蒿
+蓀
+蓁
+蓂
+蓄
+蓆
+蓉
+蓋
+蓍
+蓑
+蓓
+蓖
+蓘
+蓚
+蓧
+蓨
+蓪
+蓬
+蓭
+蓮
+蓯
+蓳
+蓼
+蓽
+蓿
+蔆
+蔎
+蔑
+蔓
+蔔
+蔕
+蔗
+蔘
+蔚
+蔝
+蔞
+蔡
+蔣
+蔥
+蔦
+蔬
+蔭
+蔴
+蔵
+蔻
+蔽
+蕁
+蕃
+蕅
+蕈
+蕉
+蕊
+蕎
+蕑
+蕒
+蕖
+蕘
+蕙
+蕚
+蕟
+蕡
+蕢
+蕤
+蕨
+蕩
+蕪
+蕭
+蕷
+蕹
+蕺
+蕻
+蕾
+薀
+薄
+薆
+薇
+薈
+薊
+薌
+薏
+薐
+薑
+薔
+薗
+薘
+薙
+薛
+薜
+薞
+薟
+薡
+薦
+薨
+薩
+薪
+薫
+薬
+薯
+薰
+薲
+薷
+薸
+薹
+薺
+薾
+薿
+藁
+藉
+藍
+藎
+藏
+藐
+藔
+藕
+藜
+藝
+藟
+藤
+藥
+藦
+藨
+藩
+藪
+藶
+藸
+藹
+藺
+藻
+藿
+蘂
+蘄
+蘅
+蘆
+蘇
+蘊
+蘋
+蘐
+蘑
+蘓
+蘗
+蘘
+蘚
+蘞
+蘢
+蘧
+蘩
+蘭
+蘵
+蘶
+蘸
+蘼
+蘿
+虉
+虎
+虐
+虓
+虔
+處
+虖
+虛
+虜
+虞
+號
+虢
+虧
+虨
+虯
+虱
+虵
+虹
+虺
+虻
+蚆
+蚊
+蚋
+蚌
+蚍
+蚓
+蚖
+蚜
+蚝
+蚡
+蚢
+蚣
+蚤
+蚧
+蚨
+蚩
+蚪
+蚯
+蚱
+蚴
+蚵
+蚶
+蚺
+蚼
+蛀
+蛄
+蛇
+蛉
+蛋
+蛍
+蛐
+蛑
+蛔
+蛙
+蛛
+蛞
+蛟
+蛤
+蛭
+蛯
+蛸
+蛹
+蛺
+蛻
+蛾
+蜀
+蜂
+蜃
+蜆
+蜇
+蜈
+蜉
+蜊
+蜍
+蜑
+蜒
+蜓
+蜘
+蜚
+蜛
+蜜
+蜞
+蜢
+蜣
+蜥
+蜨
+蜮
+蜯
+蜱
+蜴
+蜷
+蜻
+蜾
+蜿
+蝀
+蝌
+蝍
+蝎
+蝓
+蝕
+蝗
+蝘
+蝙
+蝚
+蝟
+蝠
+蝣
+蝤
+蝦
+蝨
+蝮
+蝯
+蝰
+蝲
+蝴
+蝶
+蝸
+蝽
+螂
+螃
+螄
+螅
+螈
+螋
+融
+螐
+螔
+螞
+螟
+螠
+螢
+螣
+螥
+螫
+螭
+螯
+螳
+螶
+螺
+螻
+螽
+螾
+蟀
+蟄
+蟅
+蟆
+蟊
+蟋
+蟌
+蟎
+蟑
+蟒
+蟜
+蟠
+蟥
+蟪
+蟫
+蟬
+蟯
+蟲
+蟳
+蟴
+蟶
+蟹
+蟻
+蟾
+蠂
+蠃
+蠄
+蠅
+蠆
+蠊
+蠋
+蠍
+蠐
+蠑
+蠓
+蠔
+蠕
+蠖
+蠘
+蠙
+蠟
+蠡
+蠢
+蠣
+蠱
+蠲
+蠵
+蠶
+蠷
+蠹
+蠻
+血
+衂
+衆
+行
+衍
+衎
+術
+衕
+衖
+街
+衙
+衚
+衛
+衜
+衝
+衞
+衡
+衢
+衣
+表
+衩
+衫
+衰
+衲
+衷
+衽
+衾
+衿
+袁
+袂
+袈
+袋
+袍
+袓
+袖
+袛
+袞
+袤
+袪
+被
+袱
+袴
+袾
+裁
+裂
+裊
+裎
+裒
+裔
+裕
+裖
+裘
+裙
+補
+裝
+裟
+裡
+裨
+裬
+裱
+裳
+裴
+裵
+裸
+裹
+製
+裾
+裿
+褀
+褂
+複
+褌
+褍
+褎
+褐
+褒
+褓
+褔
+褘
+褙
+褚
+褞
+褥
+褧
+褪
+褫
+褭
+褲
+褶
+褸
+褻
+襄
+襌
+襖
+襞
+襟
+襠
+襤
+襦
+襪
+襯
+襲
+襴
+襶
+襻
+襾
+西
+要
+覃
+覆
+覇
+覈
+見
+覌
+規
+覓
+視
+覚
+覡
+覦
+覧
+親
+覬
+覲
+観
+覺
+覽
+覿
+觀
+角
+觔
+觙
+觚
+觜
+解
+觭
+觱
+觴
+觶
+觸
+觿
+言
+訁
+訂
+訃
+訇
+計
+訊
+訌
+討
+訏
+訐
+訒
+訓
+訔
+訕
+訖
+託
+記
+訛
+訝
+訟
+訣
+訥
+訪
+設
+許
+訴
+訶
+診
+註
+証
+訾
+詁
+詆
+詈
+詐
+詒
+詔
+評
+詛
+詞
+詠
+詡
+詢
+詣
+詥
+試
+詧
+詩
+詫
+詭
+詮
+詰
+話
+該
+詳
+詵
+詹
+詼
+誄
+誅
+誇
+誌
+認
+誒
+誓
+誕
+誘
+語
+誠
+誡
+誣
+誤
+誥
+誦
+誨
+說
+説
+読
+誰
+課
+誴
+誹
+誼
+誾
+調
+談
+請
+諍
+諏
+諒
+論
+諗
+諜
+諟
+諠
+諡
+諤
+諦
+諧
+諪
+諫
+諭
+諮
+諱
+諲
+諳
+諴
+諶
+諷
+諸
+諺
+諼
+諾
+謀
+謁
+謂
+謄
+謇
+謊
+謌
+謎
+謏
+謐
+謔
+謖
+謗
+謙
+謚
+講
+謜
+謝
+謠
+謢
+謤
+謨
+謩
+謫
+謬
+謳
+謹
+謾
+證
+譏
+譓
+譔
+識
+譙
+譚
+譜
+譞
+警
+譫
+譬
+譭
+譯
+議
+譲
+譳
+譴
+護
+譽
+譿
+讀
+讃
+變
+讌
+讎
+讓
+讖
+讙
+讚
+讜
+讞
+谷
+谿
+豁
+豆
+豇
+豈
+豉
+豊
+豌
+豎
+豐
+豔
+豕
+豚
+象
+豢
+豨
+豪
+豫
+豬
+豳
+豸
+豹
+豺
+豿
+貂
+貅
+貉
+貊
+貌
+貐
+貒
+貓
+貔
+貘
+貝
+貞
+負
+財
+貢
+貤
+貧
+貨
+販
+貪
+貫
+責
+貭
+貮
+貯
+貲
+貳
+貴
+貶
+買
+貸
+貺
+費
+貼
+貽
+貿
+賀
+賁
+賂
+賃
+賄
+資
+賈
+賊
+賑
+賒
+賓
+賔
+賕
+賚
+賜
+賞
+賠
+賡
+賢
+賣
+賤
+賦
+賨
+質
+賬
+賭
+賴
+賹
+賺
+賻
+購
+賽
+賾
+贄
+贅
+贇
+贈
+贊
+贌
+贍
+贏
+贓
+贔
+贖
+贛
+赤
+赦
+赧
+赫
+赬
+赭
+走
+赳
+赴
+起
+趁
+超
+越
+趐
+趕
+趖
+趙
+趟
+趣
+趨
+足
+趴
+趵
+趺
+趼
+趾
+跅
+跆
+跋
+跌
+跏
+跑
+跖
+跗
+跛
+距
+跟
+跡
+跣
+跤
+跨
+跩
+跪
+路
+跳
+踎
+踏
+踐
+踝
+踞
+踢
+踩
+踰
+踴
+踹
+踺
+蹂
+蹄
+蹇
+蹈
+蹉
+蹊
+蹋
+蹕
+蹙
+蹟
+蹠
+蹤
+蹦
+蹬
+蹭
+蹯
+蹲
+蹴
+蹶
+蹺
+蹻
+蹼
+躁
+躂
+躄
+躉
+躋
+躍
+躑
+躒
+躔
+躝
+躪
+身
+躬
+躰
+躲
+躺
+軀
+車
+軋
+軌
+軍
+軎
+軒
+軔
+軛
+軟
+転
+軫
+軲
+軸
+軹
+軺
+軻
+軼
+軽
+軾
+較
+輄
+輅
+載
+輋
+輒
+輓
+輔
+輕
+輛
+輝
+輞
+輟
+輥
+輦
+輩
+輪
+輬
+輭
+輯
+輶
+輸
+輻
+輾
+輿
+轀
+轂
+轄
+轅
+轆
+轉
+轍
+轎
+轘
+轝
+轟
+轤
+辛
+辜
+辟
+辣
+辦
+辧
+辨
+辭
+辮
+辯
+辰
+辱
+農
+辵
+辺
+辻
+込
+迂
+迄
+迅
+迎
+近
+返
+迢
+迤
+迥
+迦
+迪
+迫
+迭
+迮
+述
+迴
+迵
+迷
+迸
+迺
+追
+退
+送
+逃
+逄
+逅
+逆
+逈
+逋
+逌
+逍
+逎
+透
+逐
+逑
+途
+逕
+逖
+逗
+這
+通
+逛
+逝
+逞
+速
+造
+逢
+連
+逤
+逨
+逮
+逯
+進
+逴
+逵
+逸
+逹
+逺
+逼
+逾
+遁
+遂
+遄
+遇
+遊
+運
+遍
+過
+遏
+遐
+遒
+道
+達
+違
+遘
+遙
+遛
+遜
+遞
+遠
+遢
+遣
+遨
+適
+遭
+遮
+遯
+遲
+遴
+遵
+遶
+遷
+選
+遹
+遺
+遼
+避
+邀
+邁
+邂
+邃
+還
+邇
+邈
+邉
+邊
+邋
+邏
+邑
+邕
+邗
+邙
+邛
+邠
+邡
+邢
+那
+邦
+邨
+邪
+邯
+邰
+邱
+邲
+邳
+邴
+邵
+邸
+邽
+邾
+郁
+郃
+郄
+郅
+郇
+郊
+郋
+郎
+郗
+郛
+郜
+郝
+郞
+郟
+郡
+郢
+郤
+部
+郪
+郫
+郭
+郯
+郳
+郴
+郵
+郷
+都
+郾
+郿
+鄂
+鄃
+鄄
+鄆
+鄉
+鄋
+鄑
+鄒
+鄔
+鄖
+鄗
+鄘
+鄙
+鄚
+鄜
+鄞
+鄠
+鄢
+鄣
+鄤
+鄧
+鄩
+鄫
+鄭
+鄯
+鄰
+鄱
+鄲
+鄳
+鄴
+鄺
+酃
+酆
+酈
+酉
+酊
+酋
+酌
+配
+酎
+酏
+酐
+酒
+酔
+酗
+酚
+酞
+酡
+酢
+酣
+酥
+酩
+酪
+酬
+酮
+酯
+酰
+酴
+酵
+酶
+酷
+酸
+酺
+酼
+醁
+醂
+醃
+醅
+醇
+醉
+醋
+醌
+醍
+醐
+醒
+醚
+醛
+醜
+醞
+醢
+醣
+醪
+醫
+醬
+醮
+醯
+醴
+醺
+醾
+醿
+釀
+釁
+釆
+采
+釉
+釋
+里
+重
+野
+量
+釐
+金
+釒
+釓
+釔
+釕
+釗
+釘
+釙
+釚
+釜
+針
+釣
+釤
+釦
+釧
+釩
+釪
+釭
+釴
+釵
+釷
+釹
+釺
+鈀
+鈁
+鈄
+鈇
+鈈
+鈉
+鈊
+鈍
+鈏
+鈐
+鈑
+鈔
+鈕
+鈖
+鈞
+鈢
+鈣
+鈥
+鈦
+鈫
+鈮
+鈰
+鈳
+鈴
+鈷
+鈸
+鈹
+鈺
+鈾
+鈿
+鉀
+鉄
+鉅
+鉆
+鉈
+鉉
+鉋
+鉌
+鉍
+鉏
+鉑
+鉓
+鉗
+鉚
+鉛
+鉞
+鉟
+鉤
+鉦
+鉬
+鉭
+鉲
+鉶
+鉷
+鉸
+鉻
+鉾
+鉿
+銀
+銂
+銃
+銅
+銋
+銍
+銑
+銓
+銕
+銖
+銘
+銚
+銜
+銠
+銣
+銥
+銦
+銨
+銩
+銪
+銫
+銬
+銭
+銱
+銲
+銳
+銶
+銷
+銹
+銻
+銼
+銾
+鋁
+鋅
+鋆
+鋇
+鋌
+鋏
+鋐
+鋒
+鋕
+鋗
+鋙
+鋡
+鋤
+鋥
+鋦
+鋨
+鋪
+鋮
+鋯
+鋰
+鋱
+鋳
+鋶
+鋸
+鋹
+鋼
+錀
+錄
+錏
+錐
+錒
+錕
+錘
+錚
+錞
+錟
+錠
+錡
+錢
+錦
+錨
+錫
+錬
+錮
+錯
+錳
+錶
+錸
+錻
+鍀
+鍇
+鍈
+鍉
+鍊
+鍋
+鍍
+鍏
+鍔
+鍘
+鍛
+鍝
+鍟
+鍠
+鍥
+鍩
+鍬
+鍱
+鍳
+鍵
+鍶
+鍷
+鍺
+鍼
+鍾
+鎂
+鎅
+鎊
+鎌
+鎏
+鎓
+鎔
+鎖
+鎗
+鎘
+鎚
+鎛
+鎢
+鎣
+鎦
+鎧
+鎪
+鎬
+鎭
+鎮
+鎰
+鎳
+鎵
+鎻
+鏃
+鏇
+鏈
+鏊
+鏌
+鏐
+鏑
+鏓
+鏖
+鏗
+鏘
+鏜
+鏝
+鏞
+鏟
+鏡
+鏢
+鏤
+鏦
+鏳
+鏴
+鏵
+鏷
+鏻
+鏽
+鐃
+鐇
+鐈
+鐓
+鐔
+鐘
+鐙
+鐠
+鐡
+鐤
+鐦
+鐧
+鐫
+鐬
+鐭
+鐮
+鐲
+鐳
+鐵
+鐸
+鐺
+鐽
+鐿
+鑀
+鑁
+鑂
+鑄
+鑅
+鑊
+鑌
+鑑
+鑒
+鑛
+鑠
+鑣
+鑨
+鑪
+鑫
+鑭
+鑰
+鑲
+鑴
+鑷
+鑼
+鑽
+鑾
+鑿
+長
+門
+閂
+閃
+閆
+閉
+開
+閎
+閏
+閑
+閒
+間
+閔
+閘
+閜
+閞
+閟
+関
+閣
+閥
+閦
+閨
+閩
+閬
+閭
+閰
+閱
+閶
+閹
+閻
+閼
+閾
+閿
+闆
+闇
+闈
+闊
+闋
+闌
+闍
+闐
+闓
+闔
+闕
+闖
+闘
+關
+闞
+闡
+闢
+闥
+阜
+阝
+阡
+阪
+阭
+阮
+阯
+阱
+防
+阻
+阿
+陀
+陁
+陂
+附
+陋
+陌
+降
+限
+陔
+陘
+陛
+陜
+陝
+陞
+陟
+陡
+院
+陣
+除
+陪
+陬
+陰
+陲
+陳
+陵
+陶
+陷
+陸
+険
+陽
+隄
+隅
+隆
+隈
+隊
+隋
+隍
+階
+隔
+隕
+隗
+隘
+隙
+際
+障
+隣
+隧
+隨
+險
+隰
+隱
+隲
+隳
+隴
+隷
+隸
+隹
+隻
+隼
+雀
+雁
+雄
+雅
+集
+雇
+雉
+雋
+雌
+雍
+雎
+雑
+雒
+雕
+雖
+雙
+雛
+雜
+雝
+雞
+離
+難
+雨
+雩
+雪
+雫
+雯
+雱
+雲
+零
+雷
+雹
+電
+需
+霄
+霅
+霆
+震
+霈
+霉
+霊
+霍
+霎
+霏
+霑
+霓
+霖
+霙
+霜
+霞
+霤
+霧
+霨
+霰
+露
+霶
+霸
+霹
+霽
+霾
+靁
+靂
+靄
+靈
+靉
+靑
+青
+靖
+靚
+靛
+靜
+非
+靠
+靡
+面
+革
+靫
+靬
+靭
+靳
+靴
+靶
+靺
+靼
+鞅
+鞆
+鞋
+鞍
+鞏
+鞘
+鞞
+鞠
+鞣
+鞥
+鞦
+鞨
+鞭
+鞮
+鞴
+韁
+韃
+韆
+韋
+韌
+韑
+韓
+韙
+韜
+韞
+韠
+韡
+韭
+韮
+音
+韶
+韺
+韻
+韾
+響
+頁
+頂
+頃
+項
+順
+須
+頊
+頌
+頍
+頎
+頏
+預
+頑
+頒
+頓
+頔
+頗
+領
+頜
+頠
+頡
+頤
+頦
+頫
+頭
+頰
+頴
+頵
+頷
+頸
+頹
+頻
+頼
+顆
+題
+額
+顎
+顏
+顒
+顓
+顔
+顕
+顗
+願
+顙
+顛
+類
+顥
+顧
+顫
+顯
+顰
+顱
+顳
+顴
+風
+颮
+颯
+颱
+颶
+颺
+颼
+飄
+飆
+飈
+飛
+食
+飠
+飡
+飢
+飥
+飩
+飪
+飫
+飬
+飭
+飮
+飯
+飲
+飴
+飼
+飽
+飾
+餃
+餄
+餅
+餉
+養
+餌
+餎
+餐
+餒
+餓
+餗
+餘
+餚
+餛
+餞
+餠
+餡
+館
+餮
+餵
+餺
+餾
+餿
+饃
+饅
+饋
+饌
+饑
+饒
+饕
+饗
+饞
+饟
+饢
+首
+馗
+馘
+香
+馛
+馥
+馦
+馨
+馬
+馭
+馮
+馯
+馱
+馳
+馴
+馼
+駁
+駄
+駅
+駆
+駐
+駑
+駒
+駔
+駕
+駘
+駙
+駛
+駝
+駟
+駢
+駭
+駰
+駱
+駿
+騁
+騂
+騄
+騅
+騋
+騎
+騏
+験
+騖
+騙
+騤
+騨
+騫
+騭
+騮
+騰
+騶
+騷
+騾
+驁
+驃
+驄
+驅
+驊
+驌
+驍
+驎
+驒
+驕
+驗
+驚
+驛
+驟
+驢
+驤
+驥
+驩
+驪
+骨
+骯
+骰
+骶
+骷
+骸
+骼
+髀
+髂
+髎
+髏
+髑
+髒
+髓
+體
+高
+髙
+髡
+髦
+髪
+髭
+髮
+髯
+髲
+髷
+髹
+髻
+鬃
+鬄
+鬅
+鬆
+鬍
+鬚
+鬟
+鬢
+鬣
+鬥
+鬧
+鬨
+鬩
+鬪
+鬬
+鬮
+鬯
+鬱
+鬲
+鬹
+鬻
+鬼
+魁
+魂
+魃
+魄
+魅
+魈
+魋
+魍
+魎
+魏
+魔
+魕
+魘
+魚
+魛
+魞
+魟
+魣
+魨
+魩
+魮
+魯
+魴
+魷
+鮀
+鮁
+鮃
+鮄
+鮊
+鮋
+鮍
+鮐
+鮑
+鮒
+鮓
+鮗
+鮜
+鮟
+鮠
+鮡
+鮣
+鮨
+鮪
+鮫
+鮭
+鮮
+鮰
+鮸
+鮹
+鮻
+鯀
+鯁
+鯃
+鯇
+鯉
+鯊
+鯏
+鯒
+鯓
+鯔
+鯕
+鯖
+鯗
+鯙
+鯛
+鯡
+鯢
+鯤
+鯧
+鯨
+鯪
+鯭
+鯮
+鯰
+鯶
+鯷
+鯻
+鯽
+鯿
+鰂
+鰃
+鰆
+鰈
+鰉
+鰍
+鰏
+鰒
+鰓
+鰕
+鰗
+鰛
+鰜
+鰟
+鰣
+鰤
+鰧
+鰨
+鰩
+鰭
+鰮
+鰱
+鰲
+鰳
+鰶
+鰷
+鰹
+鰺
+鰻
+鰼
+鰾
+鱀
+鱂
+鱅
+鱇
+鱈
+鱉
+鱊
+鱒
+鱓
+鱔
+鱖
+鱗
+鱘
+鱚
+鱝
+鱟
+鱠
+鱣
+鱥
+鱧
+鱨
+鱬
+鱮
+鱰
+鱲
+鱵
+鱷
+鱸
+鱺
+鱻
+鳥
+鳧
+鳩
+鳯
+鳰
+鳳
+鳴
+鳶
+鳽
+鴆
+鴇
+鴉
+鴒
+鴓
+鴕
+鴗
+鴛
+鴝
+鴞
+鴟
+鴡
+鴣
+鴦
+鴨
+鴫
+鴯
+鴰
+鴴
+鴻
+鴿
+鵂
+鵄
+鵎
+鵐
+鵑
+鵒
+鵓
+鵙
+鵜
+鵝
+鵞
+鵟
+鵠
+鵡
+鵪
+鵬
+鵯
+鵰
+鵲
+鵵
+鵼
+鵾
+鶆
+鶇
+鶉
+鶏
+鶒
+鶓
+鶘
+鶚
+鶡
+鶥
+鶩
+鶬
+鶯
+鶲
+鶴
+鶹
+鶺
+鶻
+鶼
+鶿
+鷂
+鷄
+鷉
+鷎
+鷓
+鷗
+鷙
+鷚
+鷟
+鷥
+鷦
+鷫
+鷯
+鷲
+鷳
+鷸
+鷹
+鷺
+鸊
+鸌
+鸐
+鸑
+鸕
+鸘
+鸚
+鸛
+鸜
+鸝
+鸞
+鹮
+鹵
+鹹
+鹼
+鹽
+鹿
+麂
+麅
+麇
+麈
+麊
+麋
+麐
+麒
+麓
+麗
+麝
+麞
+麟
+麥
+麩
+麪
+麯
+麴
+麵
+麹
+麺
+麻
+麼
+麽
+麾
+麿
+黁
+黃
+黇
+黌
+黍
+黎
+黏
+黐
+黑
+黒
+黔
+默
+黙
+黛
+黜
+黝
+點
+黟
+黥
+黧
+黨
+黯
+黴
+黶
+黻
+黼
+黽
+黿
+鼂
+鼇
+鼈
+鼉
+鼎
+鼐
+鼒
+鼓
+鼕
+鼙
+鼠
+鼢
+鼩
+鼬
+鼯
+鼱
+鼴
+鼷
+鼻
+鼽
+鼾
+齊
+齋
+齒
+齕
+齡
+齣
+齦
+齧
+齲
+齶
+龍
+龎
+龐
+龑
+龔
+龕
+龜
+龝
+龠
+龢
+郎
+凉
+﹑
+﹗
+﹝
+﹞
+﹢
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+A
+B
+C
+D
+E
+F
+G
+H
+I
+K
+L
+M
+N
+O
+P
+R
+S
+T
+U
+V
+W
+Y
+Z
+[
+]
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+r
+s
+t
+u
+z
+{
+|
+}
+~
+¥
+𣇉
+
diff --git a/tools/utils/dict/confuse.pkl b/tools/utils/dict/confuse.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..83c8a74b9be7873bb89997e483106bbc924963bd
--- /dev/null
+++ b/tools/utils/dict/confuse.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57d8aa98a781330f7534a81263179e77c159a37050bd31fe09d8ddf1880c5628
+size 30912
diff --git a/tools/utils/dict/cyrillic_dict.txt b/tools/utils/dict/cyrillic_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2b6f66494d5417e18bbd225719aa72690e09e126
--- /dev/null
+++ b/tools/utils/dict/cyrillic_dict.txt
@@ -0,0 +1,163 @@
+
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+Ё
+Є
+І
+Ј
+Љ
+Ў
+А
+Б
+В
+Г
+Д
+Е
+Ж
+З
+И
+Й
+К
+Л
+М
+Н
+О
+П
+Р
+С
+Т
+У
+Ф
+Х
+Ц
+Ч
+Ш
+Щ
+Ъ
+Ы
+Ь
+Э
+Ю
+Я
+а
+б
+в
+г
+д
+е
+ж
+з
+и
+й
+к
+л
+м
+н
+о
+п
+р
+с
+т
+у
+ф
+х
+ц
+ч
+ш
+щ
+ъ
+ы
+ь
+э
+ю
+я
+ё
+ђ
+є
+і
+ј
+љ
+њ
+ћ
+ў
+џ
+Ґ
+ґ
diff --git a/tools/utils/dict/devanagari_dict.txt b/tools/utils/dict/devanagari_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f55923061bfd480b875bb3679d7a75a9157387a9
--- /dev/null
+++ b/tools/utils/dict/devanagari_dict.txt
@@ -0,0 +1,167 @@
+
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+ँ
+ं
+ः
+अ
+आ
+इ
+ई
+उ
+ऊ
+ऋ
+ए
+ऐ
+ऑ
+ओ
+औ
+क
+ख
+ग
+घ
+ङ
+च
+छ
+ज
+झ
+ञ
+ट
+ठ
+ड
+ढ
+ण
+त
+थ
+द
+ध
+न
+ऩ
+प
+फ
+ब
+भ
+म
+य
+र
+ऱ
+ल
+ळ
+व
+श
+ष
+स
+ह
+़
+ा
+ि
+ी
+ु
+ू
+ृ
+ॅ
+े
+ै
+ॉ
+ो
+ौ
+्
+॒
+क़
+ख़
+ग़
+ज़
+ड़
+ढ़
+फ़
+ॠ
+।
+०
+१
+२
+३
+४
+५
+६
+७
+८
+९
+॰
diff --git a/tools/utils/dict/en_dict.txt b/tools/utils/dict/en_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6fbd99f46acca8391a5e86ae546c637399204506
--- /dev/null
+++ b/tools/utils/dict/en_dict.txt
@@ -0,0 +1,63 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+
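Each recognition dictionary above lists one symbol per line, and a symbol's line index is its label id. The sketch below shows one way such a file can be turned into index maps; load_char_dict is an illustrative helper rather than part of this patch, and the trailing blank line of en_dict.txt is assumed to stand for the space character.

def load_char_dict(dict_path, use_space_char=True):
    # One symbol per line; the line order defines the label index.
    with open(dict_path, "rb") as f:
        chars = [line.decode("utf-8").rstrip("\r\n") for line in f]
    if use_space_char and " " not in chars:
        chars.append(" ")
    char_to_idx = {ch: idx for idx, ch in enumerate(chars)}
    return chars, char_to_idx

chars, char_to_idx = load_char_dict("tools/utils/dict/en_dict.txt")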
diff --git a/tools/utils/dict/fa_dict.txt b/tools/utils/dict/fa_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2328fbd8374b3c551036a8521c1a70104925b5a8
--- /dev/null
+++ b/tools/utils/dict/fa_dict.txt
@@ -0,0 +1,136 @@
+f
+a
+_
+i
+m
+g
+/
+1
+3
+I
+L
+S
+V
+R
+C
+2
+0
+v
+l
+6
+8
+5
+.
+j
+p
+و
+د
+ر
+ك
+ن
+ش
+ه
+ا
+4
+9
+ی
+ج
+ِ
+7
+غ
+ل
+س
+ز
+ّ
+ت
+ک
+گ
+ي
+م
+ب
+ف
+چ
+خ
+ق
+ژ
+آ
+ص
+پ
+َ
+ع
+ئ
+ح
+ٔ
+ض
+ُ
+ذ
+أ
+ى
+ط
+ظ
+ث
+ة
+ً
+ء
+ؤ
+ْ
+ۀ
+إ
+ٍ
+ٌ
+ٰ
+ٓ
+ٱ
+s
+c
+e
+n
+w
+N
+E
+W
+Y
+D
+O
+H
+A
+d
+z
+r
+T
+G
+o
+t
+x
+h
+b
+B
+M
+Z
+u
+P
+F
+y
+q
+U
+K
+k
+J
+Q
+'
+X
+#
+?
+%
+$
+,
+:
+&
+!
+-
+(
+É
+@
+é
++
+
diff --git a/tools/utils/dict/french_dict.txt b/tools/utils/dict/french_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e8f657db35bf0b74f779f38a9e3b9b47b007e3c4
--- /dev/null
+++ b/tools/utils/dict/french_dict.txt
@@ -0,0 +1,136 @@
+f
+e
+n
+c
+h
+_
+i
+m
+g
+/
+r
+v
+a
+l
+t
+w
+o
+d
+6
+1
+.
+p
+B
+u
+2
+à
+3
+R
+y
+4
+U
+E
+A
+5
+P
+O
+S
+T
+D
+7
+Z
+8
+I
+N
+L
+G
+M
+H
+0
+J
+K
+-
+9
+F
+C
+V
+é
+X
+'
+s
+Q
+:
+è
+x
+b
+Y
+Œ
+É
+z
+W
+Ç
+È
+k
+Ô
+ô
+€
+À
+Ê
+q
+ù
+°
+ê
+î
+*
+Â
+j
+"
+,
+â
+%
+û
+ç
+ü
+?
+!
+;
+ö
+(
+)
+ï
+º
+ó
+ø
+å
++
+™
+á
+Ë
+<
+²
+Á
+Î
+&
+@
+œ
+ε
+Ü
+ë
+[
+]
+í
+ò
+Ö
+ä
+ß
+«
+»
+ú
+ñ
+æ
+µ
+³
+Å
+$
+#
+
diff --git a/tools/utils/dict/german_dict.txt b/tools/utils/dict/german_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5e121af21a1617dd970234ca98ae7072b0335332
--- /dev/null
+++ b/tools/utils/dict/german_dict.txt
@@ -0,0 +1,143 @@
+
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+=
+>
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+]
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+£
+§
+
+°
+´
+µ
+·
+º
+¿
+Á
+Ä
+Å
+É
+Ï
+Ô
+Ö
+Ü
+ß
+à
+á
+â
+ã
+ä
+å
+æ
+ç
+è
+é
+ê
+ë
+í
+ï
+ñ
+ò
+ó
+ô
+ö
+ø
+ù
+ú
+û
+ü
+ō
+Š
+Ÿ
+ʒ
+β
+δ
+з
+Ṡ
+‘
+€
+©
+ª
+«
+¬
diff --git a/tools/utils/dict/hi_dict.txt b/tools/utils/dict/hi_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8dfedb5ac483966de40caabe0e95118f88aa5a54
--- /dev/null
+++ b/tools/utils/dict/hi_dict.txt
@@ -0,0 +1,162 @@
+
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+ँ
+ं
+ः
+अ
+आ
+इ
+ई
+उ
+ऊ
+ऋ
+ए
+ऐ
+ऑ
+ओ
+औ
+क
+ख
+ग
+घ
+ङ
+च
+छ
+ज
+झ
+ञ
+ट
+ठ
+ड
+ढ
+ण
+त
+थ
+द
+ध
+न
+प
+फ
+ब
+भ
+म
+य
+र
+ल
+ळ
+व
+श
+ष
+स
+ह
+़
+ा
+ि
+ी
+ु
+ू
+ृ
+ॅ
+े
+ै
+ॉ
+ो
+ौ
+्
+क़
+ख़
+ग़
+ज़
+ड़
+ढ़
+फ़
+०
+१
+२
+३
+४
+५
+६
+७
+८
+९
+॰
diff --git a/tools/utils/dict/it_dict.txt b/tools/utils/dict/it_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e692c6d4335b4b8b2ed873d7923a69ed5e3d6c9a
--- /dev/null
+++ b/tools/utils/dict/it_dict.txt
@@ -0,0 +1,118 @@
+i
+t
+_
+m
+g
+/
+5
+I
+L
+S
+V
+R
+C
+2
+0
+1
+v
+a
+l
+7
+8
+9
+6
+.
+j
+p
+
+e
+r
+o
+d
+s
+n
+3
+4
+P
+u
+c
+A
+-
+,
+"
+z
+h
+f
+b
+q
+ì
+'
+à
+O
+è
+G
+ù
+é
+ò
+;
+F
+E
+B
+N
+H
+k
+:
+U
+T
+X
+D
+K
+?
+[
+M
+
+x
+y
+(
+)
+W
+ö
+º
+w
+]
+Q
+J
++
+ü
+!
+È
+á
+%
+=
+»
+ñ
+Ö
+Y
+ä
+í
+Z
+«
+@
+ó
+ø
+ï
+ú
+ê
+ç
+Á
+É
+Å
+ß
+{
+}
+&
+`
+û
+î
+#
+$
diff --git a/tools/utils/dict/japan_dict.txt b/tools/utils/dict/japan_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..339d4b89e5159a346636641a0814874faa59754a
--- /dev/null
+++ b/tools/utils/dict/japan_dict.txt
@@ -0,0 +1,4399 @@
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+]
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+©
+°
+²
+´
+½
+Á
+Ä
+Å
+Ç
+È
+É
+Í
+Ó
+Ö
+×
+Ü
+ß
+à
+á
+â
+ã
+ä
+å
+æ
+ç
+è
+é
+ê
+ë
+í
+ð
+ñ
+ò
+ó
+ô
+õ
+ö
+ø
+ú
+û
+ü
+ý
+ā
+ă
+ą
+ć
+Č
+č
+đ
+ē
+ė
+ę
+ğ
+ī
+ı
+Ł
+ł
+ń
+ň
+ō
+ř
+Ş
+ş
+Š
+š
+ţ
+ū
+ż
+Ž
+ž
+Ș
+ș
+ț
+Δ
+α
+λ
+μ
+φ
+Г
+О
+а
+в
+л
+о
+р
+с
+т
+я
+ồ
+
+—
+―
+’
+“
+”
+…
+℃
+→
+∇
+−
+■
+☆
+
+、
+。
+々
+〆
+〈
+〉
+「
+」
+『
+』
+〔
+〕
+〜
+ぁ
+あ
+ぃ
+い
+う
+ぇ
+え
+ぉ
+お
+か
+が
+き
+ぎ
+く
+ぐ
+け
+げ
+こ
+ご
+さ
+ざ
+し
+じ
+す
+ず
+せ
+ぜ
+そ
+ぞ
+た
+だ
+ち
+ぢ
+っ
+つ
+づ
+て
+で
+と
+ど
+な
+に
+ぬ
+ね
+の
+は
+ば
+ぱ
+ひ
+び
+ぴ
+ふ
+ぶ
+ぷ
+へ
+べ
+ぺ
+ほ
+ぼ
+ぽ
+ま
+み
+む
+め
+も
+ゃ
+や
+ゅ
+ゆ
+ょ
+よ
+ら
+り
+る
+れ
+ろ
+わ
+ゑ
+を
+ん
+ゝ
+ゞ
+ァ
+ア
+ィ
+イ
+ゥ
+ウ
+ェ
+エ
+ォ
+オ
+カ
+ガ
+キ
+ギ
+ク
+グ
+ケ
+ゲ
+コ
+ゴ
+サ
+ザ
+シ
+ジ
+ス
+ズ
+セ
+ゼ
+ソ
+ゾ
+タ
+ダ
+チ
+ヂ
+ッ
+ツ
+ヅ
+テ
+デ
+ト
+ド
+ナ
+ニ
+ヌ
+ネ
+ノ
+ハ
+バ
+パ
+ヒ
+ビ
+ピ
+フ
+ブ
+プ
+ヘ
+ベ
+ペ
+ホ
+ボ
+ポ
+マ
+ミ
+ム
+メ
+モ
+ャ
+ヤ
+ュ
+ユ
+ョ
+ヨ
+ラ
+リ
+ル
+レ
+ロ
+ワ
+ヰ
+ン
+ヴ
+ヵ
+ヶ
+・
+ー
+㈱
+一
+丁
+七
+万
+丈
+三
+上
+下
+不
+与
+丑
+且
+世
+丘
+丙
+丞
+両
+並
+中
+串
+丸
+丹
+主
+丼
+丿
+乃
+久
+之
+乎
+乏
+乗
+乘
+乙
+九
+乞
+也
+乱
+乳
+乾
+亀
+了
+予
+争
+事
+二
+于
+互
+五
+井
+亘
+亙
+些
+亜
+亟
+亡
+交
+亥
+亦
+亨
+享
+京
+亭
+亮
+人
+什
+仁
+仇
+今
+介
+仍
+仏
+仔
+仕
+他
+仗
+付
+仙
+代
+令
+以
+仮
+仰
+仲
+件
+任
+企
+伊
+伍
+伎
+伏
+伐
+休
+会
+伝
+伯
+估
+伴
+伶
+伸
+伺
+似
+伽
+佃
+但
+位
+低
+住
+佐
+佑
+体
+何
+余
+佚
+佛
+作
+佩
+佳
+併
+佶
+使
+侈
+例
+侍
+侏
+侑
+侘
+供
+依
+侠
+価
+侮
+侯
+侵
+侶
+便
+係
+促
+俄
+俊
+俔
+俗
+俘
+保
+信
+俣
+俤
+修
+俯
+俳
+俵
+俸
+俺
+倉
+個
+倍
+倒
+候
+借
+倣
+値
+倫
+倭
+倶
+倹
+偃
+假
+偈
+偉
+偏
+偐
+偕
+停
+健
+側
+偵
+偶
+偽
+傀
+傅
+傍
+傑
+傘
+備
+催
+傭
+傲
+傳
+債
+傷
+傾
+僊
+働
+像
+僑
+僕
+僚
+僧
+僭
+僮
+儀
+億
+儇
+儒
+儛
+償
+儡
+優
+儲
+儺
+儼
+兀
+允
+元
+兄
+充
+兆
+先
+光
+克
+兌
+免
+兎
+児
+党
+兜
+入
+全
+八
+公
+六
+共
+兵
+其
+具
+典
+兼
+内
+円
+冊
+再
+冑
+冒
+冗
+写
+冠
+冤
+冥
+冨
+冬
+冲
+决
+冶
+冷
+准
+凉
+凋
+凌
+凍
+凛
+凝
+凞
+几
+凡
+処
+凪
+凰
+凱
+凶
+凸
+凹
+出
+函
+刀
+刃
+分
+切
+刈
+刊
+刎
+刑
+列
+初
+判
+別
+利
+刪
+到
+制
+刷
+券
+刹
+刺
+刻
+剃
+則
+削
+剋
+前
+剖
+剛
+剣
+剤
+剥
+剪
+副
+剰
+割
+創
+剽
+劇
+劉
+劔
+力
+功
+加
+劣
+助
+努
+劫
+劭
+励
+労
+効
+劾
+勃
+勅
+勇
+勉
+勒
+動
+勘
+務
+勝
+募
+勢
+勤
+勧
+勲
+勺
+勾
+勿
+匁
+匂
+包
+匏
+化
+北
+匙
+匝
+匠
+匡
+匣
+匯
+匲
+匹
+区
+医
+匿
+十
+千
+升
+午
+卉
+半
+卍
+卑
+卒
+卓
+協
+南
+単
+博
+卜
+占
+卦
+卯
+印
+危
+即
+却
+卵
+卸
+卿
+厄
+厚
+原
+厠
+厨
+厩
+厭
+厳
+去
+参
+又
+叉
+及
+友
+双
+反
+収
+叔
+取
+受
+叙
+叛
+叟
+叡
+叢
+口
+古
+句
+叩
+只
+叫
+召
+可
+台
+叱
+史
+右
+叶
+号
+司
+吃
+各
+合
+吉
+吊
+同
+名
+后
+吏
+吐
+向
+君
+吝
+吟
+吠
+否
+含
+吸
+吹
+吻
+吽
+吾
+呂
+呆
+呈
+呉
+告
+呑
+周
+呪
+呰
+味
+呼
+命
+咀
+咄
+咋
+和
+咒
+咫
+咲
+咳
+咸
+哀
+品
+哇
+哉
+員
+哨
+哩
+哭
+哲
+哺
+唄
+唆
+唇
+唐
+唖
+唯
+唱
+唳
+唸
+唾
+啄
+商
+問
+啓
+啼
+善
+喋
+喚
+喜
+喝
+喧
+喩
+喪
+喫
+喬
+單
+喰
+営
+嗅
+嗇
+嗔
+嗚
+嗜
+嗣
+嘆
+嘉
+嘗
+嘘
+嘩
+嘯
+嘱
+嘲
+嘴
+噂
+噌
+噛
+器
+噴
+噺
+嚆
+嚢
+囀
+囃
+囉
+囚
+四
+回
+因
+団
+困
+囲
+図
+固
+国
+圀
+圃
+國
+圏
+園
+圓
+團
+圜
+土
+圧
+在
+圭
+地
+址
+坂
+均
+坊
+坐
+坑
+坡
+坤
+坦
+坪
+垂
+型
+垢
+垣
+埃
+埋
+城
+埒
+埔
+域
+埠
+埴
+埵
+執
+培
+基
+埼
+堀
+堂
+堅
+堆
+堕
+堤
+堪
+堯
+堰
+報
+場
+堵
+堺
+塀
+塁
+塊
+塑
+塔
+塗
+塘
+塙
+塚
+塞
+塩
+填
+塵
+塾
+境
+墉
+墓
+増
+墜
+墟
+墨
+墳
+墺
+墻
+墾
+壁
+壇
+壊
+壌
+壕
+士
+壬
+壮
+声
+壱
+売
+壷
+壹
+壺
+壽
+変
+夏
+夕
+外
+夙
+多
+夜
+夢
+夥
+大
+天
+太
+夫
+夬
+夭
+央
+失
+夷
+夾
+奄
+奇
+奈
+奉
+奎
+奏
+契
+奔
+奕
+套
+奘
+奠
+奢
+奥
+奨
+奪
+奮
+女
+奴
+奸
+好
+如
+妃
+妄
+妊
+妍
+妓
+妖
+妙
+妥
+妨
+妬
+妲
+妹
+妻
+妾
+姉
+始
+姐
+姓
+委
+姚
+姜
+姞
+姥
+姦
+姨
+姪
+姫
+姶
+姻
+姿
+威
+娑
+娘
+娟
+娠
+娩
+娯
+娼
+婆
+婉
+婚
+婢
+婦
+婬
+婿
+媄
+媒
+媓
+媚
+媛
+媞
+媽
+嫁
+嫄
+嫉
+嫌
+嫐
+嫗
+嫡
+嬉
+嬌
+嬢
+嬪
+嬬
+嬾
+孁
+子
+孔
+字
+存
+孚
+孝
+孟
+季
+孤
+学
+孫
+孵
+學
+宅
+宇
+守
+安
+宋
+完
+宍
+宏
+宕
+宗
+官
+宙
+定
+宛
+宜
+宝
+実
+客
+宣
+室
+宥
+宮
+宰
+害
+宴
+宵
+家
+宸
+容
+宿
+寂
+寄
+寅
+密
+寇
+富
+寒
+寓
+寔
+寛
+寝
+察
+寡
+實
+寧
+審
+寮
+寵
+寶
+寸
+寺
+対
+寿
+封
+専
+射
+将
+尉
+尊
+尋
+對
+導
+小
+少
+尖
+尚
+尤
+尪
+尭
+就
+尹
+尺
+尻
+尼
+尽
+尾
+尿
+局
+居
+屈
+届
+屋
+屍
+屎
+屏
+屑
+屓
+展
+属
+屠
+層
+履
+屯
+山
+岐
+岑
+岡
+岩
+岫
+岬
+岳
+岷
+岸
+峠
+峡
+峨
+峯
+峰
+島
+峻
+崇
+崋
+崎
+崑
+崖
+崗
+崛
+崩
+嵌
+嵐
+嵩
+嵯
+嶂
+嶋
+嶠
+嶺
+嶼
+嶽
+巀
+巌
+巒
+巖
+川
+州
+巡
+巣
+工
+左
+巧
+巨
+巫
+差
+己
+巳
+巴
+巷
+巻
+巽
+巾
+市
+布
+帆
+希
+帖
+帚
+帛
+帝
+帥
+師
+席
+帯
+帰
+帳
+帷
+常
+帽
+幄
+幅
+幇
+幌
+幔
+幕
+幟
+幡
+幢
+幣
+干
+平
+年
+并
+幸
+幹
+幻
+幼
+幽
+幾
+庁
+広
+庄
+庇
+床
+序
+底
+庖
+店
+庚
+府
+度
+座
+庫
+庭
+庵
+庶
+康
+庸
+廂
+廃
+廉
+廊
+廓
+廟
+廠
+廣
+廬
+延
+廷
+建
+廻
+廼
+廿
+弁
+弄
+弉
+弊
+弌
+式
+弐
+弓
+弔
+引
+弖
+弗
+弘
+弛
+弟
+弥
+弦
+弧
+弱
+張
+強
+弼
+弾
+彈
+彊
+彌
+彎
+当
+彗
+彙
+彝
+形
+彦
+彩
+彫
+彬
+彭
+彰
+影
+彷
+役
+彼
+往
+征
+徂
+径
+待
+律
+後
+徐
+徑
+徒
+従
+得
+徠
+御
+徧
+徨
+復
+循
+徭
+微
+徳
+徴
+德
+徹
+徽
+心
+必
+忉
+忌
+忍
+志
+忘
+忙
+応
+忠
+快
+忯
+念
+忻
+忽
+忿
+怒
+怖
+思
+怠
+怡
+急
+性
+怨
+怪
+怯
+恂
+恋
+恐
+恒
+恕
+恣
+恤
+恥
+恨
+恩
+恬
+恭
+息
+恵
+悉
+悌
+悍
+悔
+悟
+悠
+患
+悦
+悩
+悪
+悲
+悼
+情
+惇
+惑
+惚
+惜
+惟
+惠
+惣
+惧
+惨
+惰
+想
+惹
+惺
+愈
+愉
+愍
+意
+愔
+愚
+愛
+感
+愷
+愿
+慈
+態
+慌
+慎
+慕
+慢
+慣
+慧
+慨
+慮
+慰
+慶
+憂
+憎
+憐
+憑
+憙
+憤
+憧
+憩
+憬
+憲
+憶
+憾
+懇
+應
+懌
+懐
+懲
+懸
+懺
+懽
+懿
+戈
+戊
+戌
+戎
+成
+我
+戒
+戔
+或
+戚
+戟
+戦
+截
+戮
+戯
+戴
+戸
+戻
+房
+所
+扁
+扇
+扈
+扉
+手
+才
+打
+払
+托
+扮
+扱
+扶
+批
+承
+技
+抄
+把
+抑
+抓
+投
+抗
+折
+抜
+択
+披
+抱
+抵
+抹
+押
+抽
+担
+拇
+拈
+拉
+拍
+拏
+拐
+拒
+拓
+拘
+拙
+招
+拝
+拠
+拡
+括
+拭
+拳
+拵
+拶
+拾
+拿
+持
+挂
+指
+按
+挑
+挙
+挟
+挨
+振
+挺
+挽
+挿
+捉
+捕
+捗
+捜
+捧
+捨
+据
+捺
+捻
+掃
+掄
+授
+掌
+排
+掖
+掘
+掛
+掟
+採
+探
+掣
+接
+控
+推
+掩
+措
+掬
+掲
+掴
+掻
+掾
+揃
+揄
+揆
+揉
+描
+提
+揖
+揚
+換
+握
+揮
+援
+揶
+揺
+損
+搦
+搬
+搭
+携
+搾
+摂
+摘
+摩
+摸
+摺
+撃
+撒
+撞
+撤
+撥
+撫
+播
+撮
+撰
+撲
+撹
+擁
+操
+擔
+擦
+擬
+擾
+攘
+攝
+攣
+支
+收
+改
+攻
+放
+政
+故
+敏
+救
+敗
+教
+敢
+散
+敦
+敬
+数
+整
+敵
+敷
+斂
+文
+斉
+斎
+斐
+斑
+斗
+料
+斜
+斟
+斤
+斥
+斧
+斬
+断
+斯
+新
+方
+於
+施
+旁
+旅
+旋
+旌
+族
+旗
+旛
+无
+旡
+既
+日
+旦
+旧
+旨
+早
+旬
+旭
+旺
+旻
+昂
+昆
+昇
+昉
+昌
+明
+昏
+易
+昔
+星
+映
+春
+昧
+昨
+昪
+昭
+是
+昵
+昼
+晁
+時
+晃
+晋
+晏
+晒
+晟
+晦
+晧
+晩
+普
+景
+晴
+晶
+智
+暁
+暇
+暈
+暉
+暑
+暖
+暗
+暘
+暢
+暦
+暫
+暮
+暲
+暴
+暹
+暾
+曄
+曇
+曉
+曖
+曙
+曜
+曝
+曠
+曰
+曲
+曳
+更
+書
+曹
+曼
+曽
+曾
+替
+最
+會
+月
+有
+朋
+服
+朏
+朔
+朕
+朗
+望
+朝
+期
+朧
+木
+未
+末
+本
+札
+朱
+朴
+机
+朽
+杁
+杉
+李
+杏
+材
+村
+杓
+杖
+杜
+杞
+束
+条
+杢
+杣
+来
+杭
+杮
+杯
+東
+杲
+杵
+杷
+杼
+松
+板
+枅
+枇
+析
+枓
+枕
+林
+枚
+果
+枝
+枠
+枡
+枢
+枯
+枳
+架
+柄
+柊
+柏
+某
+柑
+染
+柔
+柘
+柚
+柯
+柱
+柳
+柴
+柵
+査
+柾
+柿
+栂
+栃
+栄
+栖
+栗
+校
+株
+栲
+栴
+核
+根
+栻
+格
+栽
+桁
+桂
+桃
+框
+案
+桐
+桑
+桓
+桔
+桜
+桝
+桟
+桧
+桴
+桶
+桾
+梁
+梅
+梆
+梓
+梔
+梗
+梛
+條
+梟
+梢
+梧
+梨
+械
+梱
+梲
+梵
+梶
+棄
+棋
+棒
+棗
+棘
+棚
+棟
+棠
+森
+棲
+棹
+棺
+椀
+椅
+椋
+植
+椎
+椏
+椒
+椙
+検
+椥
+椹
+椿
+楊
+楓
+楕
+楚
+楞
+楠
+楡
+楢
+楨
+楪
+楫
+業
+楮
+楯
+楳
+極
+楷
+楼
+楽
+概
+榊
+榎
+榕
+榛
+榜
+榮
+榱
+榴
+槃
+槇
+槊
+構
+槌
+槍
+槐
+様
+槙
+槻
+槽
+槿
+樂
+樋
+樓
+樗
+標
+樟
+模
+権
+横
+樫
+樵
+樹
+樺
+樽
+橇
+橋
+橘
+機
+橿
+檀
+檄
+檎
+檐
+檗
+檜
+檣
+檥
+檬
+檮
+檸
+檻
+櫃
+櫓
+櫛
+櫟
+櫨
+櫻
+欄
+欅
+欠
+次
+欣
+欧
+欲
+欺
+欽
+款
+歌
+歎
+歓
+止
+正
+此
+武
+歩
+歪
+歯
+歳
+歴
+死
+殆
+殉
+殊
+残
+殖
+殯
+殴
+段
+殷
+殺
+殻
+殿
+毀
+毅
+母
+毎
+毒
+比
+毘
+毛
+毫
+毬
+氈
+氏
+民
+気
+水
+氷
+永
+氾
+汀
+汁
+求
+汎
+汐
+汗
+汚
+汝
+江
+池
+汪
+汰
+汲
+決
+汽
+沂
+沃
+沅
+沆
+沈
+沌
+沐
+沓
+沖
+沙
+没
+沢
+沱
+河
+沸
+油
+治
+沼
+沽
+沿
+況
+泉
+泊
+泌
+法
+泗
+泡
+波
+泣
+泥
+注
+泯
+泰
+泳
+洋
+洒
+洗
+洛
+洞
+津
+洩
+洪
+洲
+洸
+洹
+活
+洽
+派
+流
+浄
+浅
+浙
+浚
+浜
+浣
+浦
+浩
+浪
+浮
+浴
+海
+浸
+涅
+消
+涌
+涙
+涛
+涯
+液
+涵
+涼
+淀
+淄
+淆
+淇
+淋
+淑
+淘
+淡
+淤
+淨
+淫
+深
+淳
+淵
+混
+淹
+添
+清
+済
+渉
+渋
+渓
+渕
+渚
+減
+渟
+渠
+渡
+渤
+渥
+渦
+温
+渫
+測
+港
+游
+渾
+湊
+湖
+湘
+湛
+湧
+湫
+湯
+湾
+湿
+満
+源
+準
+溜
+溝
+溢
+溥
+溪
+溶
+溺
+滄
+滅
+滋
+滌
+滑
+滕
+滝
+滞
+滴
+滸
+滹
+滿
+漁
+漂
+漆
+漉
+漏
+漑
+演
+漕
+漠
+漢
+漣
+漫
+漬
+漱
+漸
+漿
+潅
+潔
+潙
+潜
+潟
+潤
+潭
+潮
+潰
+潴
+澁
+澂
+澄
+澎
+澗
+澤
+澪
+澱
+澳
+激
+濁
+濃
+濟
+濠
+濡
+濤
+濫
+濯
+濱
+濾
+瀉
+瀋
+瀑
+瀕
+瀞
+瀟
+瀧
+瀬
+瀾
+灌
+灑
+灘
+火
+灯
+灰
+灸
+災
+炉
+炊
+炎
+炒
+炭
+炮
+炷
+点
+為
+烈
+烏
+烙
+烝
+烹
+焔
+焙
+焚
+無
+焦
+然
+焼
+煇
+煉
+煌
+煎
+煕
+煙
+煤
+煥
+照
+煩
+煬
+煮
+煽
+熈
+熊
+熙
+熟
+熨
+熱
+熹
+熾
+燃
+燈
+燎
+燔
+燕
+燗
+燥
+燭
+燻
+爆
+爐
+爪
+爬
+爲
+爵
+父
+爺
+爼
+爽
+爾
+片
+版
+牌
+牒
+牘
+牙
+牛
+牝
+牟
+牡
+牢
+牧
+物
+牲
+特
+牽
+犂
+犠
+犬
+犯
+状
+狂
+狄
+狐
+狗
+狙
+狛
+狡
+狩
+独
+狭
+狷
+狸
+狼
+猊
+猛
+猟
+猥
+猨
+猩
+猪
+猫
+献
+猴
+猶
+猷
+猾
+猿
+獄
+獅
+獏
+獣
+獲
+玄
+玅
+率
+玉
+王
+玖
+玩
+玲
+珀
+珂
+珈
+珉
+珊
+珍
+珎
+珞
+珠
+珣
+珥
+珪
+班
+現
+球
+理
+琉
+琢
+琥
+琦
+琮
+琲
+琳
+琴
+琵
+琶
+瑁
+瑋
+瑙
+瑚
+瑛
+瑜
+瑞
+瑠
+瑤
+瑩
+瑪
+瑳
+瑾
+璃
+璋
+璜
+璞
+璧
+璨
+環
+璵
+璽
+璿
+瓊
+瓔
+瓜
+瓢
+瓦
+瓶
+甍
+甑
+甕
+甘
+甚
+甞
+生
+産
+甥
+用
+甫
+田
+由
+甲
+申
+男
+町
+画
+界
+畏
+畑
+畔
+留
+畜
+畝
+畠
+畢
+略
+番
+異
+畳
+當
+畷
+畸
+畺
+畿
+疆
+疇
+疋
+疎
+疏
+疑
+疫
+疱
+疲
+疹
+疼
+疾
+病
+症
+痒
+痔
+痕
+痘
+痙
+痛
+痢
+痩
+痴
+痺
+瘍
+瘡
+瘧
+療
+癇
+癌
+癒
+癖
+癡
+癪
+発
+登
+白
+百
+的
+皆
+皇
+皋
+皐
+皓
+皮
+皺
+皿
+盂
+盃
+盆
+盈
+益
+盒
+盗
+盛
+盞
+盟
+盡
+監
+盤
+盥
+盧
+目
+盲
+直
+相
+盾
+省
+眉
+看
+県
+眞
+真
+眠
+眷
+眺
+眼
+着
+睡
+督
+睦
+睨
+睿
+瞋
+瞑
+瞞
+瞬
+瞭
+瞰
+瞳
+瞻
+瞼
+瞿
+矍
+矛
+矜
+矢
+知
+矧
+矩
+短
+矮
+矯
+石
+砂
+砌
+研
+砕
+砥
+砦
+砧
+砲
+破
+砺
+硝
+硫
+硬
+硯
+碁
+碇
+碌
+碑
+碓
+碕
+碗
+碣
+碧
+碩
+確
+碾
+磁
+磐
+磔
+磧
+磨
+磬
+磯
+礁
+礎
+礒
+礙
+礫
+礬
+示
+礼
+社
+祀
+祁
+祇
+祈
+祉
+祐
+祓
+祕
+祖
+祗
+祚
+祝
+神
+祟
+祠
+祢
+祥
+票
+祭
+祷
+祺
+禁
+禄
+禅
+禊
+禍
+禎
+福
+禔
+禖
+禛
+禦
+禧
+禮
+禰
+禹
+禽
+禿
+秀
+私
+秋
+科
+秒
+秘
+租
+秤
+秦
+秩
+称
+移
+稀
+程
+税
+稔
+稗
+稙
+稚
+稜
+稠
+種
+稱
+稲
+稷
+稻
+稼
+稽
+稿
+穀
+穂
+穆
+積
+穎
+穏
+穗
+穜
+穢
+穣
+穫
+穴
+究
+空
+突
+窃
+窄
+窒
+窓
+窟
+窠
+窩
+窪
+窮
+窯
+竃
+竄
+竈
+立
+站
+竜
+竝
+竟
+章
+童
+竪
+竭
+端
+竴
+競
+竹
+竺
+竽
+竿
+笄
+笈
+笏
+笑
+笙
+笛
+笞
+笠
+笥
+符
+第
+笹
+筅
+筆
+筇
+筈
+等
+筋
+筌
+筍
+筏
+筐
+筑
+筒
+答
+策
+筝
+筥
+筧
+筬
+筮
+筯
+筰
+筵
+箆
+箇
+箋
+箏
+箒
+箔
+箕
+算
+箙
+箜
+管
+箪
+箭
+箱
+箸
+節
+篁
+範
+篆
+篇
+築
+篋
+篌
+篝
+篠
+篤
+篥
+篦
+篩
+篭
+篳
+篷
+簀
+簒
+簡
+簧
+簪
+簫
+簺
+簾
+簿
+籀
+籃
+籌
+籍
+籐
+籟
+籠
+籤
+籬
+米
+籾
+粂
+粉
+粋
+粒
+粕
+粗
+粘
+粛
+粟
+粥
+粧
+粮
+粳
+精
+糊
+糖
+糜
+糞
+糟
+糠
+糧
+糯
+糸
+糺
+系
+糾
+紀
+約
+紅
+紋
+納
+紐
+純
+紗
+紘
+紙
+級
+紛
+素
+紡
+索
+紫
+紬
+累
+細
+紳
+紵
+紹
+紺
+絁
+終
+絃
+組
+絅
+経
+結
+絖
+絞
+絡
+絣
+給
+統
+絲
+絵
+絶
+絹
+絽
+綏
+經
+継
+続
+綜
+綟
+綬
+維
+綱
+網
+綴
+綸
+綺
+綽
+綾
+綿
+緊
+緋
+総
+緑
+緒
+線
+締
+緥
+編
+緩
+緬
+緯
+練
+緻
+縁
+縄
+縅
+縒
+縛
+縞
+縢
+縣
+縦
+縫
+縮
+縹
+總
+績
+繁
+繊
+繋
+繍
+織
+繕
+繝
+繦
+繧
+繰
+繹
+繼
+纂
+纈
+纏
+纐
+纒
+纛
+缶
+罔
+罠
+罧
+罪
+置
+罰
+署
+罵
+罷
+罹
+羂
+羅
+羆
+羇
+羈
+羊
+羌
+美
+群
+羨
+義
+羯
+羲
+羹
+羽
+翁
+翅
+翌
+習
+翔
+翛
+翠
+翡
+翫
+翰
+翺
+翻
+翼
+耀
+老
+考
+者
+耆
+而
+耐
+耕
+耗
+耨
+耳
+耶
+耽
+聊
+聖
+聘
+聚
+聞
+聟
+聡
+聨
+聯
+聰
+聲
+聴
+職
+聾
+肄
+肆
+肇
+肉
+肋
+肌
+肖
+肘
+肛
+肝
+股
+肢
+肥
+肩
+肪
+肯
+肱
+育
+肴
+肺
+胃
+胆
+背
+胎
+胖
+胚
+胝
+胞
+胡
+胤
+胱
+胴
+胸
+能
+脂
+脅
+脆
+脇
+脈
+脊
+脚
+脛
+脩
+脱
+脳
+腋
+腎
+腐
+腑
+腔
+腕
+腫
+腰
+腱
+腸
+腹
+腺
+腿
+膀
+膏
+膚
+膜
+膝
+膠
+膣
+膨
+膩
+膳
+膵
+膾
+膿
+臂
+臆
+臈
+臍
+臓
+臘
+臚
+臣
+臥
+臨
+自
+臭
+至
+致
+臺
+臼
+舂
+舅
+與
+興
+舌
+舍
+舎
+舒
+舖
+舗
+舘
+舜
+舞
+舟
+舩
+航
+般
+舳
+舶
+船
+艇
+艘
+艦
+艮
+良
+色
+艶
+芋
+芒
+芙
+芝
+芥
+芦
+芬
+芭
+芯
+花
+芳
+芸
+芹
+芻
+芽
+芿
+苅
+苑
+苔
+苗
+苛
+苞
+苡
+若
+苦
+苧
+苫
+英
+苴
+苻
+茂
+范
+茄
+茅
+茎
+茗
+茘
+茜
+茨
+茲
+茵
+茶
+茸
+茹
+草
+荊
+荏
+荒
+荘
+荷
+荻
+荼
+莞
+莪
+莫
+莬
+莱
+莵
+莽
+菅
+菊
+菌
+菓
+菖
+菘
+菜
+菟
+菩
+菫
+華
+菱
+菴
+萄
+萊
+萌
+萍
+萎
+萠
+萩
+萬
+萱
+落
+葉
+著
+葛
+葡
+董
+葦
+葩
+葬
+葭
+葱
+葵
+葺
+蒋
+蒐
+蒔
+蒙
+蒟
+蒡
+蒲
+蒸
+蒻
+蒼
+蒿
+蓄
+蓆
+蓉
+蓋
+蓑
+蓬
+蓮
+蓼
+蔀
+蔑
+蔓
+蔚
+蔡
+蔦
+蔬
+蔭
+蔵
+蔽
+蕃
+蕉
+蕊
+蕎
+蕨
+蕩
+蕪
+蕭
+蕾
+薄
+薇
+薊
+薔
+薗
+薙
+薛
+薦
+薨
+薩
+薪
+薫
+薬
+薭
+薮
+藁
+藉
+藍
+藏
+藐
+藝
+藤
+藩
+藪
+藷
+藹
+藺
+藻
+蘂
+蘆
+蘇
+蘊
+蘭
+虎
+虐
+虔
+虚
+虜
+虞
+號
+虫
+虹
+虻
+蚊
+蚕
+蛇
+蛉
+蛍
+蛎
+蛙
+蛛
+蛟
+蛤
+蛭
+蛮
+蛸
+蛹
+蛾
+蜀
+蜂
+蜃
+蜆
+蜊
+蜘
+蜜
+蜷
+蜻
+蝉
+蝋
+蝕
+蝙
+蝠
+蝦
+蝶
+蝿
+螂
+融
+螣
+螺
+蟄
+蟇
+蟠
+蟷
+蟹
+蟻
+蠢
+蠣
+血
+衆
+行
+衍
+衒
+術
+街
+衙
+衛
+衝
+衞
+衡
+衢
+衣
+表
+衫
+衰
+衵
+衷
+衽
+衾
+衿
+袁
+袈
+袋
+袍
+袒
+袖
+袙
+袞
+袢
+被
+袰
+袱
+袴
+袷
+袿
+裁
+裂
+裃
+装
+裏
+裔
+裕
+裘
+裙
+補
+裟
+裡
+裲
+裳
+裴
+裸
+裹
+製
+裾
+褂
+褄
+複
+褌
+褐
+褒
+褥
+褪
+褶
+褻
+襄
+襖
+襞
+襟
+襠
+襦
+襪
+襲
+襴
+襷
+西
+要
+覆
+覇
+覈
+見
+規
+視
+覗
+覚
+覧
+親
+覲
+観
+覺
+觀
+角
+解
+触
+言
+訂
+計
+討
+訓
+託
+記
+訛
+訟
+訢
+訥
+訪
+設
+許
+訳
+訴
+訶
+診
+註
+証
+詐
+詔
+評
+詛
+詞
+詠
+詢
+詣
+試
+詩
+詫
+詮
+詰
+話
+該
+詳
+誄
+誅
+誇
+誉
+誌
+認
+誓
+誕
+誘
+語
+誠
+誡
+誣
+誤
+誥
+誦
+説
+読
+誰
+課
+誼
+誾
+調
+談
+請
+諌
+諍
+諏
+諒
+論
+諚
+諜
+諟
+諡
+諦
+諧
+諫
+諭
+諮
+諱
+諶
+諷
+諸
+諺
+諾
+謀
+謄
+謌
+謎
+謗
+謙
+謚
+講
+謝
+謡
+謫
+謬
+謹
+證
+識
+譚
+譛
+譜
+警
+譬
+譯
+議
+譲
+譴
+護
+讀
+讃
+讐
+讒
+谷
+谿
+豅
+豆
+豊
+豎
+豐
+豚
+象
+豪
+豫
+豹
+貌
+貝
+貞
+負
+財
+貢
+貧
+貨
+販
+貪
+貫
+責
+貯
+貰
+貴
+買
+貸
+費
+貼
+貿
+賀
+賁
+賂
+賃
+賄
+資
+賈
+賊
+賎
+賑
+賓
+賛
+賜
+賞
+賠
+賢
+賣
+賤
+賦
+質
+賭
+購
+賽
+贄
+贅
+贈
+贋
+贔
+贖
+赤
+赦
+走
+赴
+起
+超
+越
+趙
+趣
+足
+趺
+趾
+跋
+跏
+距
+跡
+跨
+跪
+路
+跳
+践
+踊
+踏
+踐
+踞
+踪
+踵
+蹄
+蹉
+蹊
+蹟
+蹲
+蹴
+躅
+躇
+躊
+躍
+躑
+躙
+躪
+身
+躬
+躯
+躰
+車
+軋
+軌
+軍
+軒
+軟
+転
+軸
+軻
+軽
+軾
+較
+載
+輌
+輔
+輜
+輝
+輦
+輩
+輪
+輯
+輸
+輿
+轄
+轍
+轟
+轢
+辛
+辞
+辟
+辥
+辦
+辨
+辰
+辱
+農
+辺
+辻
+込
+迂
+迅
+迎
+近
+返
+迢
+迦
+迪
+迫
+迭
+述
+迷
+迹
+追
+退
+送
+逃
+逅
+逆
+逍
+透
+逐
+逓
+途
+逕
+逗
+這
+通
+逝
+逞
+速
+造
+逢
+連
+逮
+週
+進
+逸
+逼
+遁
+遂
+遅
+遇
+遊
+運
+遍
+過
+遐
+道
+達
+違
+遙
+遜
+遠
+遡
+遣
+遥
+適
+遭
+遮
+遯
+遵
+遷
+選
+遺
+遼
+避
+邀
+邁
+邂
+邃
+還
+邇
+邉
+邊
+邑
+那
+邦
+邨
+邪
+邯
+邵
+邸
+郁
+郊
+郎
+郡
+郢
+部
+郭
+郴
+郵
+郷
+都
+鄂
+鄙
+鄭
+鄰
+鄲
+酉
+酋
+酌
+配
+酎
+酒
+酔
+酢
+酥
+酪
+酬
+酵
+酷
+酸
+醍
+醐
+醒
+醗
+醜
+醤
+醪
+醵
+醸
+采
+釈
+釉
+釋
+里
+重
+野
+量
+釐
+金
+釘
+釜
+針
+釣
+釧
+釿
+鈍
+鈎
+鈐
+鈔
+鈞
+鈦
+鈴
+鈷
+鈸
+鈿
+鉄
+鉇
+鉉
+鉋
+鉛
+鉢
+鉤
+鉦
+鉱
+鉾
+銀
+銃
+銅
+銈
+銑
+銕
+銘
+銚
+銜
+銭
+鋏
+鋒
+鋤
+鋭
+鋲
+鋳
+鋸
+鋺
+鋼
+錆
+錍
+錐
+錘
+錠
+錣
+錦
+錫
+錬
+錯
+録
+錵
+鍋
+鍍
+鍑
+鍔
+鍛
+鍬
+鍮
+鍵
+鍼
+鍾
+鎌
+鎖
+鎗
+鎚
+鎧
+鎬
+鎮
+鎰
+鎹
+鏃
+鏑
+鏡
+鐃
+鐇
+鐐
+鐔
+鐘
+鐙
+鐚
+鐡
+鐵
+鐸
+鑁
+鑊
+鑑
+鑒
+鑚
+鑠
+鑢
+鑰
+鑵
+鑷
+鑼
+鑽
+鑿
+長
+門
+閃
+閇
+閉
+開
+閏
+閑
+間
+閔
+閘
+関
+閣
+閤
+閥
+閦
+閨
+閬
+閲
+閻
+閼
+閾
+闇
+闍
+闔
+闕
+闘
+關
+闡
+闢
+闥
+阜
+阪
+阮
+阯
+防
+阻
+阿
+陀
+陂
+附
+陌
+降
+限
+陛
+陞
+院
+陣
+除
+陥
+陪
+陬
+陰
+陳
+陵
+陶
+陸
+険
+陽
+隅
+隆
+隈
+隊
+隋
+階
+随
+隔
+際
+障
+隠
+隣
+隧
+隷
+隻
+隼
+雀
+雁
+雄
+雅
+集
+雇
+雉
+雊
+雋
+雌
+雍
+雑
+雖
+雙
+雛
+離
+難
+雨
+雪
+雫
+雰
+雲
+零
+雷
+雹
+電
+需
+震
+霊
+霍
+霖
+霜
+霞
+霧
+霰
+露
+靈
+青
+靖
+静
+靜
+非
+面
+革
+靫
+靭
+靱
+靴
+靺
+鞁
+鞄
+鞆
+鞋
+鞍
+鞏
+鞘
+鞠
+鞨
+鞭
+韋
+韓
+韜
+韮
+音
+韶
+韻
+響
+頁
+頂
+頃
+項
+順
+須
+頌
+預
+頑
+頒
+頓
+領
+頚
+頬
+頭
+頴
+頸
+頻
+頼
+顆
+題
+額
+顎
+顔
+顕
+顗
+願
+顛
+類
+顧
+顯
+風
+飛
+食
+飢
+飩
+飫
+飯
+飲
+飴
+飼
+飽
+飾
+餃
+餅
+餉
+養
+餌
+餐
+餓
+餘
+餝
+餡
+館
+饂
+饅
+饉
+饋
+饌
+饒
+饗
+首
+馗
+香
+馨
+馬
+馳
+馴
+駄
+駅
+駆
+駈
+駐
+駒
+駕
+駝
+駿
+騁
+騎
+騏
+騒
+験
+騙
+騨
+騰
+驕
+驚
+驛
+驢
+骨
+骸
+髄
+體
+高
+髙
+髢
+髪
+髭
+髮
+髷
+髻
+鬘
+鬚
+鬢
+鬨
+鬯
+鬱
+鬼
+魁
+魂
+魄
+魅
+魏
+魔
+魚
+魯
+鮎
+鮑
+鮒
+鮪
+鮫
+鮭
+鮮
+鯉
+鯔
+鯖
+鯛
+鯨
+鯰
+鯱
+鰐
+鰒
+鰭
+鰯
+鰰
+鰹
+鰻
+鱈
+鱒
+鱗
+鱧
+鳥
+鳩
+鳰
+鳳
+鳴
+鳶
+鴈
+鴉
+鴎
+鴛
+鴟
+鴦
+鴨
+鴫
+鴻
+鵄
+鵜
+鵞
+鵡
+鵬
+鵲
+鵺
+鶉
+鶏
+鶯
+鶴
+鷄
+鷙
+鷲
+鷹
+鷺
+鸚
+鸞
+鹸
+鹽
+鹿
+麁
+麒
+麓
+麗
+麝
+麞
+麟
+麦
+麩
+麹
+麺
+麻
+麾
+麿
+黄
+黌
+黍
+黒
+黙
+黛
+黠
+鼈
+鼉
+鼎
+鼓
+鼠
+鼻
+齊
+齋
+齟
+齢
+齬
+龍
+龕
+龗
+!
+#
+%
+&
+(
+)
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+=
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+R
+S
+T
+U
+V
+W
+X
+Z
+a
+c
+d
+e
+f
+h
+i
+j
+k
+l
+m
+n
+o
+p
+r
+s
+t
+u
+y
+z
+~
+・
+
diff --git a/tools/utils/dict/ka_dict.txt b/tools/utils/dict/ka_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d506b691bd1a6c55299ad89a72cf3a69a2c879a9
--- /dev/null
+++ b/tools/utils/dict/ka_dict.txt
@@ -0,0 +1,153 @@
+k
+a
+_
+i
+m
+g
+/
+1
+2
+I
+L
+S
+V
+R
+C
+0
+v
+l
+6
+4
+8
+.
+j
+p
+ಗ
+ು
+ಣ
+ಪ
+ಡ
+ಿ
+ಸ
+ಲ
+ಾ
+ದ
+್
+7
+5
+3
+ವ
+ಷ
+ಬ
+ಹ
+ೆ
+9
+ಅ
+ಳ
+ನ
+ರ
+ಉ
+ಕ
+ಎ
+ೇ
+ಂ
+ೈ
+ೊ
+ೀ
+ಯ
+ೋ
+ತ
+ಶ
+ಭ
+ಧ
+ಚ
+ಜ
+ೂ
+ಮ
+ಒ
+ೃ
+ಥ
+ಇ
+ಟ
+ಖ
+ಆ
+ಞ
+ಫ
+-
+ಢ
+ಊ
+ಓ
+ಐ
+ಃ
+ಘ
+ಝ
+ೌ
+ಠ
+ಛ
+ಔ
+ಏ
+ಈ
+ಋ
+೨
+೦
+೧
+೮
+೯
+೪
+,
+೫
+೭
+೩
+೬
+ಙ
+s
+c
+e
+n
+w
+o
+u
+t
+d
+E
+A
+T
+B
+Z
+N
+G
+O
+q
+z
+r
+x
+P
+K
+M
+J
+U
+D
+f
+F
+h
+b
+W
+Y
+y
+H
+X
+Q
+'
+#
+&
+!
+@
+$
+:
+%
+é
+É
+(
+?
++
+
diff --git a/tools/utils/dict/kie_dict/xfund_class_list.txt b/tools/utils/dict/kie_dict/xfund_class_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..faded9f9b8f56bd258909bec9b8f1755aa688367
--- /dev/null
+++ b/tools/utils/dict/kie_dict/xfund_class_list.txt
@@ -0,0 +1,4 @@
+OTHER
+QUESTION
+ANSWER
+HEADER
diff --git a/tools/utils/dict/korean_dict.txt b/tools/utils/dict/korean_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a13899f14dfe3bfc25b34904390c7b1e4ed8674b
--- /dev/null
+++ b/tools/utils/dict/korean_dict.txt
@@ -0,0 +1,3688 @@
+!
+"
+#
+$
+%
+&
+'
+*
++
+-
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+\
+]
+^
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+|
+}
+~
+©
+°
+²
+½
+Á
+Ä
+Å
+Ç
+É
+Í
+Î
+Ó
+Ö
+×
+Ü
+ß
+à
+á
+â
+ã
+ä
+å
+æ
+ç
+è
+é
+ê
+ë
+ì
+í
+î
+ï
+ð
+ñ
+ò
+ó
+ô
+õ
+ö
+ø
+ú
+û
+ü
+ý
+ā
+ă
+ą
+ć
+Č
+č
+đ
+ē
+ė
+ę
+ě
+ğ
+ī
+İ
+ı
+Ł
+ł
+ń
+ň
+ō
+ř
+Ş
+ş
+Š
+š
+ţ
+ū
+ź
+ż
+Ž
+ž
+Ș
+ș
+Α
+Δ
+α
+λ
+φ
+Г
+О
+а
+в
+л
+о
+р
+с
+т
+я
+
+’
+“
+”
+→
+∇
+∼
+「
+」
+ア
+カ
+グ
+ニ
+ラ
+ン
+ㄱ
+ㄴ
+ㄷ
+ㄸ
+ㄹ
+ㅂ
+ㅅ
+ㅆ
+ㅇ
+ㅈ
+ㅊ
+ㅋ
+ㅌ
+ㅎ
+ㅓ
+ㅜ
+ㅣ
+一
+丁
+七
+三
+上
+下
+不
+丑
+世
+丘
+丞
+中
+丸
+丹
+主
+乃
+久
+之
+乎
+乘
+九
+也
+乳
+乾
+事
+二
+云
+互
+五
+井
+亞
+亡
+交
+亥
+亨
+享
+京
+亭
+人
+仁
+今
+他
+仙
+代
+令
+以
+仰
+仲
+件
+任
+企
+伊
+伍
+伎
+伏
+伐
+休
+伯
+伴
+伸
+佃
+佈
+位
+低
+住
+佐
+何
+佛
+作
+使
+來
+供
+依
+侯
+侵
+侶
+便
+俗
+保
+俠
+信
+修
+俱
+俳
+倉
+個
+倍
+倒
+候
+借
+値
+倫
+倭
+假
+偈
+偉
+偏
+停
+偶
+傅
+傑
+傳
+傷
+傾
+像
+僞
+僥
+僧
+價
+儀
+儉
+儒
+優
+儼
+兀
+允
+元
+兆
+先
+光
+克
+兒
+入
+內
+全
+八
+公
+六
+共
+兵
+其
+具
+典
+兼
+再
+冠
+冥
+冶
+准
+凞
+凡
+凱
+出
+函
+刀
+分
+刊
+刑
+列
+初
+判
+別
+利
+到
+制
+券
+刺
+刻
+則
+前
+剛
+副
+創
+劃
+劑
+力
+功
+加
+劣
+助
+劫
+勇
+動
+務
+勝
+勢
+勳
+勸
+匈
+化
+北
+匠
+區
+十
+千
+午
+半
+卍
+卑
+卒
+卓
+南
+博
+卜
+占
+卦
+印
+危
+卵
+卷
+卽
+卿
+厄
+原
+厦
+去
+參
+又
+叉
+友
+反
+叔
+受
+口
+古
+句
+可
+台
+史
+右
+司
+各
+合
+吉
+同
+名
+后
+吏
+吐
+君
+吠
+吳
+呂
+告
+周
+味
+呵
+命
+和
+咳
+咸
+咽
+哀
+品
+哨
+哮
+哲
+唐
+唯
+唱
+商
+問
+啼
+善
+喆
+喉
+喜
+喩
+喪
+嘗
+器
+嚴
+囊
+四
+回
+因
+困
+固
+圈
+國
+圍
+園
+圓
+圖
+團
+土
+在
+地
+均
+坊
+坐
+坑
+坵
+型
+垢
+城
+域
+埴
+執
+培
+基
+堂
+堅
+堆
+堤
+堯
+報
+場
+塔
+塚
+塞
+塵
+境
+墜
+墟
+墨
+墳
+墾
+壁
+壇
+壓
+壤
+士
+壬
+壯
+壺
+壽
+夏
+夕
+外
+多
+夜
+夢
+大
+天
+太
+夫
+央
+失
+夷
+奄
+奇
+奉
+奎
+奏
+契
+奔
+奮
+女
+奴
+好
+如
+妄
+妊
+妖
+妙
+始
+姑
+姓
+姚
+姜
+威
+婆
+婚
+婦
+媒
+媚
+子
+孔
+字
+存
+孝
+孟
+季
+孤
+孫
+學
+孺
+宇
+守
+安
+宋
+宗
+官
+宙
+定
+客
+宣
+室
+宮
+害
+家
+容
+寂
+寃
+寄
+寅
+密
+寇
+富
+寒
+寓
+實
+審
+寫
+寬
+寶
+寸
+寺
+封
+將
+專
+尊
+對
+小
+少
+尙
+尹
+尼
+尿
+局
+居
+屈
+屋
+屍
+屎
+屛
+層
+屬
+山
+岐
+岡
+岩
+岳
+岸
+峙
+峰
+島
+峻
+峽
+崇
+崔
+崖
+崩
+嶋
+巖
+川
+州
+巢
+工
+左
+巧
+巨
+巫
+差
+己
+巷
+市
+布
+帝
+師
+帶
+常
+帽
+幕
+干
+平
+年
+幹
+幻
+幼
+幽
+庇
+序
+店
+府
+度
+座
+庫
+庭
+康
+廟
+廣
+廳
+延
+廷
+建
+廻
+弁
+式
+弑
+弓
+引
+弘
+弟
+弱
+張
+强
+弼
+彌
+彛
+形
+彬
+影
+役
+彼
+彿
+往
+征
+待
+律
+後
+徐
+徑
+得
+從
+循
+微
+德
+徹
+心
+必
+忌
+忍
+志
+忠
+思
+怡
+急
+性
+恐
+恒
+恨
+恩
+悅
+悖
+患
+悲
+情
+惑
+惟
+惠
+惡
+想
+惺
+愁
+意
+愚
+愛
+感
+愼
+慈
+態
+慕
+慣
+慧
+慾
+憂
+憤
+憺
+應
+懸
+戎
+成
+我
+戟
+戮
+戰
+戴
+戶
+房
+所
+手
+才
+打
+批
+承
+技
+抄
+把
+抗
+抱
+抽
+拇
+拓
+拘
+拙
+拜
+拾
+持
+指
+捌
+捨
+捿
+授
+掌
+排
+接
+推
+提
+揚
+揭
+援
+損
+搗
+摩
+播
+操
+擒
+擔
+擘
+據
+擧
+攘
+攝
+攬
+支
+改
+攻
+放
+政
+故
+敍
+敎
+救
+敗
+散
+敬
+整
+數
+文
+斗
+料
+斛
+斜
+斧
+斯
+新
+斷
+方
+於
+施
+旋
+族
+旗
+日
+旨
+早
+旱
+昌
+明
+易
+昔
+星
+春
+昧
+昭
+是
+時
+晉
+晋
+晩
+普
+景
+晴
+晶
+智
+暈
+暑
+暗
+暘
+曉
+曜
+曠
+曦
+曰
+曲
+書
+曹
+曼
+曾
+最
+會
+月
+有
+朋
+服
+望
+朝
+期
+木
+未
+末
+本
+朱
+朴
+李
+材
+村
+杖
+杜
+杞
+杭
+杯
+東
+松
+板
+林
+果
+枝
+枯
+枰
+枾
+柏
+柑
+柱
+栗
+校
+栢
+核
+根
+格
+桀
+桂
+案
+桎
+桑
+桓
+桔
+梁
+梏
+梓
+梗
+條
+梨
+梵
+棗
+棟
+森
+植
+椒
+楊
+楓
+楚
+業
+楮
+極
+榮
+槃
+槍
+樂
+樓
+樗
+樣
+樸
+樹
+樺
+樽
+橄
+橋
+橘
+機
+橡
+檀
+檎
+權
+欌
+欖
+次
+欲
+歌
+歐
+止
+正
+此
+步
+武
+歲
+歸
+死
+殖
+段
+殷
+殺
+殿
+毅
+母
+毒
+比
+毛
+氏
+民
+氣
+水
+永
+求
+汎
+汗
+江
+池
+沅
+沒
+沖
+沙
+沛
+河
+油
+治
+沼
+沿
+泉
+泊
+法
+泗
+泡
+波
+注
+泰
+洋
+洙
+洛
+洞
+津
+洲
+活
+派
+流
+浅
+浦
+浮
+浴
+海
+涅
+涇
+消
+涌
+液
+淑
+淡
+淨
+淫
+深
+淳
+淵
+淸
+渠
+渡
+游
+渾
+湖
+湯
+源
+溪
+溫
+溶
+滄
+滅
+滋
+滯
+滿
+漁
+漆
+漢
+漫
+漸
+潑
+潤
+潭
+澄
+澎
+澤
+澳
+澹
+濁
+濕
+濟
+濤
+濯
+瀋
+瀝
+灣
+火
+灰
+灸
+災
+炎
+炭
+点
+烈
+烏
+烙
+焚
+無
+焦
+然
+煌
+煎
+照
+煬
+煮
+熟
+熱
+燁
+燈
+燔
+燕
+燥
+燧
+燮
+爲
+爵
+父
+片
+版
+牌
+牛
+牝
+牟
+牡
+物
+特
+犧
+犬
+狀
+狗
+猥
+猩
+猪
+獨
+獵
+獸
+獻
+玄
+玉
+王
+玲
+珍
+珠
+珪
+班
+現
+球
+理
+琴
+瑞
+瑟
+瑪
+璃
+璋
+璽
+瓜
+瓦
+甑
+甘
+生
+産
+用
+甫
+田
+由
+甲
+申
+男
+界
+畏
+留
+畜
+畢
+略
+番
+異
+畵
+當
+畸
+疏
+疑
+疫
+疹
+疼
+病
+症
+痔
+痛
+痺
+瘀
+瘍
+瘡
+療
+癌
+癖
+登
+發
+白
+百
+的
+皆
+皇
+皮
+盂
+盆
+益
+盛
+盜
+盟
+盡
+盤
+盧
+目
+直
+相
+省
+看
+眞
+眼
+睡
+督
+瞋
+矢
+矣
+知
+短
+石
+破
+碍
+碑
+磁
+磨
+磬
+示
+社
+祇
+祖
+祝
+神
+祥
+祭
+祺
+禁
+禅
+禍
+福
+禦
+禪
+禮
+禹
+禽
+禾
+秀
+私
+秉
+秋
+科
+秘
+秤
+秦
+秩
+移
+稀
+稗
+種
+稱
+稷
+稼
+稽
+穀
+穆
+積
+空
+窮
+竅
+立
+章
+童
+竭
+端
+竹
+笑
+符
+第
+筆
+等
+筍
+答
+策
+箋
+箕
+管
+箱
+節
+篇
+簡
+米
+粉
+粘
+粥
+精
+糖
+糞
+系
+紀
+紂
+約
+紅
+紋
+純
+紙
+級
+素
+索
+紫
+紬
+累
+細
+紳
+終
+組
+結
+絡
+統
+絲
+絶
+絹
+經
+綠
+維
+綱
+網
+綸
+綽
+緖
+線
+緣
+緯
+縣
+縱
+總
+織
+繡
+繩
+繪
+繭
+纂
+續
+罕
+置
+罰
+羅
+羊
+美
+群
+義
+羽
+翁
+習
+翟
+老
+考
+者
+而
+耐
+耕
+耳
+聃
+聖
+聞
+聰
+聲
+職
+肇
+肉
+肖
+肝
+股
+肥
+育
+肺
+胃
+胎
+胚
+胞
+胡
+胥
+能
+脂
+脈
+脚
+脛
+脣
+脩
+脫
+脯
+脾
+腋
+腎
+腫
+腸
+腹
+膜
+膠
+膨
+膽
+臆
+臟
+臣
+臥
+臨
+自
+至
+致
+臺
+臼
+臾
+與
+興
+舊
+舌
+舍
+舒
+舜
+舟
+般
+船
+艦
+良
+色
+芋
+花
+芳
+芽
+苑
+苔
+苕
+苛
+苞
+若
+苦
+英
+茂
+茵
+茶
+茹
+荀
+荇
+草
+荒
+荷
+莊
+莫
+菊
+菌
+菜
+菩
+菫
+華
+菴
+菽
+萊
+萍
+萬
+落
+葉
+著
+葛
+董
+葬
+蒙
+蒜
+蒲
+蒸
+蒿
+蓮
+蔓
+蔘
+蔡
+蔬
+蕃
+蕉
+蕓
+薄
+薑
+薛
+薩
+薪
+薺
+藏
+藝
+藤
+藥
+藩
+藻
+蘆
+蘇
+蘊
+蘚
+蘭
+虎
+處
+虛
+虞
+虹
+蜀
+蜂
+蜜
+蝕
+蝶
+融
+蟬
+蟲
+蠶
+蠻
+血
+衆
+行
+術
+衛
+衡
+衣
+表
+袁
+裔
+裕
+裙
+補
+製
+複
+襄
+西
+要
+見
+視
+親
+覺
+觀
+角
+解
+言
+訂
+訊
+訓
+託
+記
+訣
+設
+診
+註
+評
+詩
+話
+詵
+誅
+誌
+認
+誕
+語
+誠
+誤
+誥
+誦
+說
+調
+談
+諍
+論
+諡
+諫
+諭
+諸
+謙
+講
+謝
+謠
+證
+識
+譚
+譜
+譯
+議
+護
+讀
+變
+谷
+豆
+豊
+豚
+象
+豪
+豫
+貝
+貞
+財
+貧
+貨
+貪
+貫
+貴
+貸
+費
+資
+賊
+賓
+賞
+賢
+賣
+賦
+質
+贍
+赤
+赫
+走
+起
+超
+越
+趙
+趣
+趨
+足
+趾
+跋
+跡
+路
+踏
+蹟
+身
+躬
+車
+軍
+軒
+軟
+載
+輓
+輕
+輪
+輯
+輸
+輻
+輿
+轅
+轉
+辨
+辭
+辯
+辰
+農
+近
+迦
+述
+追
+逆
+透
+逐
+通
+逝
+造
+逢
+連
+進
+逵
+遂
+遊
+運
+遍
+過
+道
+達
+遠
+遡
+適
+遷
+選
+遺
+遽
+還
+邊
+邑
+那
+邪
+郞
+郡
+部
+都
+鄒
+鄕
+鄭
+鄲
+配
+酒
+酸
+醉
+醫
+醯
+釋
+里
+重
+野
+量
+釐
+金
+針
+鈍
+鈴
+鉞
+銀
+銅
+銘
+鋼
+錄
+錢
+錦
+鎭
+鏡
+鐘
+鐵
+鑑
+鑛
+長
+門
+閃
+開
+間
+閔
+閣
+閥
+閭
+閻
+闕
+關
+阪
+防
+阿
+陀
+降
+限
+陝
+院
+陰
+陳
+陵
+陶
+陸
+陽
+隆
+隊
+隋
+階
+際
+障
+隣
+隨
+隱
+隷
+雀
+雄
+雅
+集
+雇
+雌
+雖
+雙
+雜
+離
+難
+雨
+雪
+雲
+電
+霜
+露
+靈
+靑
+靖
+靜
+非
+面
+革
+靴
+鞏
+韓
+音
+韶
+韻
+順
+須
+頊
+頌
+領
+頭
+顔
+願
+顚
+類
+顯
+風
+飛
+食
+飢
+飮
+飯
+飾
+養
+餓
+餘
+首
+香
+馨
+馬
+駒
+騫
+騷
+驕
+骨
+骸
+髓
+體
+高
+髥
+髮
+鬪
+鬱
+鬼
+魏
+魔
+魚
+魯
+鮮
+鰍
+鰐
+鳥
+鳧
+鳳
+鴨
+鵲
+鶴
+鷄
+鷹
+鹽
+鹿
+麗
+麥
+麻
+黃
+黑
+默
+點
+黨
+鼎
+齊
+齋
+齒
+龍
+龜
+가
+각
+간
+갇
+갈
+갉
+감
+갑
+값
+갓
+갔
+강
+갖
+갗
+같
+갚
+갛
+개
+객
+갠
+갤
+갬
+갭
+갯
+갰
+갱
+갸
+걀
+걔
+걘
+거
+걱
+건
+걷
+걸
+검
+겁
+것
+겄
+겅
+겆
+겉
+겊
+겋
+게
+겐
+겔
+겟
+겠
+겡
+겨
+격
+겪
+견
+결
+겸
+겹
+겻
+겼
+경
+곁
+계
+곕
+곗
+고
+곡
+곤
+곧
+골
+곪
+곬
+곯
+곰
+곱
+곳
+공
+곶
+과
+곽
+관
+괄
+괌
+광
+괘
+괜
+괭
+괴
+괸
+굉
+교
+구
+국
+군
+굳
+굴
+굵
+굶
+굼
+굽
+굿
+궁
+궂
+궈
+권
+궐
+궜
+궝
+궤
+귀
+귄
+귈
+귓
+규
+균
+귤
+그
+극
+근
+글
+긁
+금
+급
+긋
+긍
+기
+긴
+길
+김
+깁
+깃
+깅
+깊
+까
+깍
+깎
+깐
+깔
+깜
+깝
+깟
+깡
+깥
+깨
+깬
+깰
+깻
+깼
+깽
+꺄
+꺼
+꺽
+꺾
+껀
+껄
+껌
+껍
+껏
+껐
+껑
+께
+껴
+꼈
+꼍
+꼐
+꼬
+꼭
+꼴
+꼼
+꼽
+꼿
+꽁
+꽂
+꽃
+꽉
+꽝
+꽤
+꽥
+꾀
+꾜
+꾸
+꾹
+꾼
+꿀
+꿇
+꿈
+꿉
+꿋
+꿍
+꿎
+꿔
+꿨
+꿩
+꿰
+꿴
+뀄
+뀌
+뀐
+뀔
+뀜
+뀝
+끄
+끈
+끊
+끌
+끓
+끔
+끕
+끗
+끙
+끝
+끼
+끽
+낀
+낄
+낌
+낍
+낏
+낑
+나
+낙
+낚
+난
+낟
+날
+낡
+남
+납
+낫
+났
+낭
+낮
+낯
+낱
+낳
+내
+낵
+낸
+낼
+냄
+냅
+냇
+냈
+냉
+냐
+냔
+냘
+냥
+너
+넉
+넋
+넌
+널
+넓
+넘
+넙
+넛
+넜
+넝
+넣
+네
+넥
+넨
+넬
+넴
+넵
+넷
+넸
+넹
+녀
+녁
+년
+념
+녔
+녕
+녘
+녜
+노
+녹
+논
+놀
+놈
+놋
+농
+높
+놓
+놔
+놨
+뇌
+뇨
+뇩
+뇽
+누
+눅
+눈
+눌
+눔
+눕
+눗
+눠
+눴
+뉘
+뉜
+뉩
+뉴
+늄
+늅
+늉
+느
+늑
+는
+늘
+늙
+늠
+늡
+능
+늦
+늪
+늬
+니
+닉
+닌
+닐
+님
+닙
+닛
+닝
+닢
+다
+닥
+닦
+단
+닫
+달
+닭
+닮
+닯
+닳
+담
+답
+닷
+당
+닻
+닿
+대
+댁
+댄
+댈
+댐
+댑
+댓
+댔
+댕
+댜
+더
+덕
+덖
+던
+덜
+덟
+덤
+덥
+덧
+덩
+덫
+덮
+데
+덱
+덴
+델
+뎀
+뎃
+뎅
+뎌
+뎠
+뎨
+도
+독
+돈
+돋
+돌
+돔
+돕
+돗
+동
+돛
+돝
+돼
+됐
+되
+된
+될
+됨
+됩
+됴
+두
+둑
+둔
+둘
+둠
+둡
+둣
+둥
+둬
+뒀
+뒤
+뒬
+뒷
+뒹
+듀
+듈
+듐
+드
+득
+든
+듣
+들
+듦
+듬
+듭
+듯
+등
+듸
+디
+딕
+딘
+딛
+딜
+딤
+딥
+딧
+딨
+딩
+딪
+따
+딱
+딴
+딸
+땀
+땄
+땅
+때
+땐
+땔
+땜
+땝
+땠
+땡
+떠
+떡
+떤
+떨
+떫
+떰
+떱
+떳
+떴
+떵
+떻
+떼
+떽
+뗀
+뗄
+뗍
+뗏
+뗐
+뗑
+또
+똑
+똘
+똥
+뙤
+뚜
+뚝
+뚤
+뚫
+뚱
+뛰
+뛴
+뛸
+뜀
+뜁
+뜨
+뜩
+뜬
+뜯
+뜰
+뜸
+뜻
+띄
+띈
+띌
+띔
+띕
+띠
+띤
+띨
+띱
+띵
+라
+락
+란
+랄
+람
+랍
+랏
+랐
+랑
+랒
+랗
+래
+랙
+랜
+랠
+램
+랩
+랫
+랬
+랭
+랴
+략
+량
+러
+럭
+런
+럴
+럼
+럽
+럿
+렀
+렁
+렇
+레
+렉
+렌
+렐
+렘
+렙
+렛
+렝
+려
+력
+련
+렬
+렴
+렵
+렷
+렸
+령
+례
+로
+록
+론
+롤
+롬
+롭
+롯
+롱
+롸
+롹
+뢰
+뢴
+뢸
+룃
+료
+룐
+룡
+루
+룩
+룬
+룰
+룸
+룹
+룻
+룽
+뤄
+뤘
+뤼
+류
+륙
+륜
+률
+륨
+륭
+르
+륵
+른
+를
+름
+릅
+릇
+릉
+릎
+리
+릭
+린
+릴
+림
+립
+릿
+링
+마
+막
+만
+많
+맏
+말
+맑
+맘
+맙
+맛
+망
+맞
+맡
+맣
+매
+맥
+맨
+맬
+맴
+맵
+맷
+맸
+맹
+맺
+먀
+먁
+머
+먹
+먼
+멀
+멈
+멋
+멍
+멎
+메
+멕
+멘
+멜
+멤
+멥
+멧
+멩
+며
+멱
+면
+멸
+몄
+명
+몇
+모
+목
+몫
+몬
+몰
+몸
+몹
+못
+몽
+뫼
+묘
+무
+묵
+묶
+문
+묻
+물
+묽
+뭄
+뭅
+뭇
+뭉
+뭍
+뭏
+뭐
+뭔
+뭘
+뭡
+뭣
+뮈
+뮌
+뮐
+뮤
+뮬
+므
+믈
+믐
+미
+믹
+민
+믿
+밀
+밈
+밉
+밋
+밌
+밍
+및
+밑
+바
+박
+밖
+반
+받
+발
+밝
+밟
+밤
+밥
+밧
+방
+밭
+배
+백
+밴
+밸
+뱀
+뱁
+뱃
+뱄
+뱅
+뱉
+뱍
+뱐
+버
+벅
+번
+벌
+범
+법
+벗
+벙
+벚
+베
+벡
+벤
+벨
+벰
+벱
+벳
+벵
+벼
+벽
+변
+별
+볍
+볏
+볐
+병
+볕
+보
+복
+볶
+본
+볼
+봄
+봅
+봇
+봉
+봐
+봤
+뵈
+뵐
+뵙
+부
+북
+분
+붇
+불
+붉
+붐
+붓
+붕
+붙
+뷔
+뷰
+뷴
+뷸
+브
+븐
+블
+비
+빅
+빈
+빌
+빔
+빕
+빗
+빙
+빚
+빛
+빠
+빡
+빤
+빨
+빳
+빴
+빵
+빻
+빼
+빽
+뺀
+뺄
+뺌
+뺏
+뺐
+뺑
+뺨
+뻐
+뻑
+뻔
+뻗
+뻘
+뻣
+뻤
+뻥
+뻬
+뼈
+뼉
+뼘
+뽀
+뽈
+뽐
+뽑
+뽕
+뾰
+뿌
+뿍
+뿐
+뿔
+뿜
+쁘
+쁜
+쁠
+쁨
+삐
+삔
+삘
+사
+삭
+삯
+산
+살
+삵
+삶
+삼
+삽
+삿
+샀
+상
+샅
+새
+색
+샌
+샐
+샘
+샙
+샛
+샜
+생
+샤
+샨
+샬
+샴
+샵
+샷
+샹
+서
+석
+섞
+선
+섣
+설
+섬
+섭
+섯
+섰
+성
+섶
+세
+섹
+센
+셀
+셈
+셉
+셋
+셌
+셍
+셔
+션
+셜
+셨
+셰
+셴
+셸
+소
+속
+손
+솔
+솜
+솝
+솟
+송
+솥
+쇄
+쇠
+쇤
+쇳
+쇼
+숀
+숄
+숍
+수
+숙
+순
+숟
+술
+숨
+숩
+숫
+숭
+숯
+숱
+숲
+숴
+쉐
+쉘
+쉬
+쉭
+쉰
+쉴
+쉼
+쉽
+슈
+슐
+슘
+슛
+슝
+스
+슥
+슨
+슬
+슭
+슴
+습
+슷
+승
+시
+식
+신
+싣
+실
+싫
+심
+십
+싯
+싱
+싶
+싸
+싹
+싼
+쌀
+쌈
+쌉
+쌌
+쌍
+쌓
+쌔
+쌘
+쌩
+써
+썩
+썬
+썰
+썸
+썹
+썼
+썽
+쎄
+쎈
+쏘
+쏙
+쏜
+쏟
+쏠
+쏭
+쏴
+쐈
+쐐
+쐬
+쑤
+쑥
+쑨
+쒀
+쒔
+쓰
+쓱
+쓴
+쓸
+씀
+씁
+씌
+씨
+씩
+씬
+씰
+씸
+씹
+씻
+씽
+아
+악
+안
+앉
+않
+알
+앎
+앓
+암
+압
+앗
+았
+앙
+앞
+애
+액
+앤
+앨
+앰
+앱
+앳
+앴
+앵
+야
+약
+얀
+얄
+얇
+얌
+얍
+얏
+양
+얕
+얗
+얘
+얜
+어
+억
+언
+얹
+얻
+얼
+얽
+엄
+업
+없
+엇
+었
+엉
+엊
+엌
+엎
+에
+엑
+엔
+엘
+엠
+엡
+엣
+엥
+여
+역
+엮
+연
+열
+엷
+염
+엽
+엾
+엿
+였
+영
+옅
+옆
+옇
+예
+옌
+옐
+옙
+옛
+오
+옥
+온
+올
+옭
+옮
+옳
+옴
+옵
+옷
+옹
+옻
+와
+왁
+완
+왈
+왑
+왓
+왔
+왕
+왜
+왠
+왱
+외
+왼
+요
+욕
+욘
+욜
+욤
+용
+우
+욱
+운
+울
+움
+웁
+웃
+웅
+워
+웍
+원
+월
+웜
+웠
+웡
+웨
+웬
+웰
+웸
+웹
+위
+윅
+윈
+윌
+윔
+윗
+윙
+유
+육
+윤
+율
+윱
+윳
+융
+으
+윽
+은
+을
+읊
+음
+읍
+응
+의
+읜
+읠
+이
+익
+인
+일
+읽
+잃
+임
+입
+잇
+있
+잉
+잊
+잎
+자
+작
+잔
+잖
+잘
+잠
+잡
+잣
+잤
+장
+잦
+재
+잭
+잰
+잴
+잽
+잿
+쟀
+쟁
+쟈
+쟉
+쟤
+저
+적
+전
+절
+젊
+점
+접
+젓
+정
+젖
+제
+젝
+젠
+젤
+젬
+젭
+젯
+져
+젼
+졀
+졌
+졍
+조
+족
+존
+졸
+좀
+좁
+종
+좇
+좋
+좌
+좍
+좽
+죄
+죠
+죤
+주
+죽
+준
+줄
+줌
+줍
+줏
+중
+줘
+줬
+쥐
+쥔
+쥘
+쥬
+쥴
+즈
+즉
+즌
+즐
+즘
+즙
+증
+지
+직
+진
+짇
+질
+짊
+짐
+집
+짓
+징
+짖
+짙
+짚
+짜
+짝
+짠
+짢
+짤
+짧
+짬
+짭
+짰
+짱
+째
+짹
+짼
+쨀
+쨉
+쨋
+쨌
+쨍
+쩄
+쩌
+쩍
+쩐
+쩔
+쩜
+쩝
+쩡
+쩨
+쪄
+쪘
+쪼
+쪽
+쪾
+쫀
+쫄
+쫑
+쫓
+쫙
+쬐
+쭈
+쭉
+쭐
+쭙
+쯔
+쯤
+쯧
+찌
+찍
+찐
+찔
+찜
+찝
+찡
+찢
+찧
+차
+착
+찬
+찮
+찰
+참
+찹
+찻
+찼
+창
+찾
+채
+책
+챈
+챌
+챔
+챕
+챗
+챘
+챙
+챠
+챤
+처
+척
+천
+철
+첨
+첩
+첫
+청
+체
+첵
+첸
+첼
+쳄
+쳇
+쳉
+쳐
+쳔
+쳤
+초
+촉
+촌
+촘
+촛
+총
+촨
+촬
+최
+쵸
+추
+축
+춘
+출
+춤
+춥
+춧
+충
+춰
+췄
+췌
+취
+췬
+츄
+츠
+측
+츨
+츰
+층
+치
+칙
+친
+칠
+칡
+침
+칩
+칫
+칭
+카
+칵
+칸
+칼
+캄
+캅
+캇
+캉
+캐
+캔
+캘
+캠
+캡
+캣
+캤
+캥
+캬
+커
+컥
+컨
+컫
+컬
+컴
+컵
+컷
+컸
+컹
+케
+켄
+켈
+켐
+켓
+켕
+켜
+켠
+켤
+켭
+켯
+켰
+코
+콕
+콘
+콜
+콤
+콥
+콧
+콩
+콰
+콱
+콴
+콸
+쾅
+쾌
+쾡
+쾨
+쾰
+쿄
+쿠
+쿡
+쿤
+쿨
+쿰
+쿵
+쿼
+퀀
+퀄
+퀘
+퀭
+퀴
+퀵
+퀸
+퀼
+큐
+큘
+크
+큰
+클
+큼
+큽
+키
+킥
+킨
+킬
+킴
+킵
+킷
+킹
+타
+탁
+탄
+탈
+탉
+탐
+탑
+탓
+탔
+탕
+태
+택
+탠
+탤
+탬
+탭
+탯
+탰
+탱
+터
+턱
+턴
+털
+텀
+텁
+텃
+텄
+텅
+테
+텍
+텐
+텔
+템
+텝
+텡
+텨
+톈
+토
+톡
+톤
+톨
+톰
+톱
+톳
+통
+퇴
+툇
+투
+툭
+툰
+툴
+툼
+퉁
+퉈
+퉜
+튀
+튄
+튈
+튕
+튜
+튠
+튤
+튬
+트
+특
+튼
+튿
+틀
+틈
+틉
+틋
+틔
+티
+틱
+틴
+틸
+팀
+팁
+팅
+파
+팍
+팎
+판
+팔
+팜
+팝
+팟
+팠
+팡
+팥
+패
+팩
+팬
+팰
+팸
+팻
+팼
+팽
+퍼
+퍽
+펀
+펄
+펌
+펍
+펐
+펑
+페
+펙
+펜
+펠
+펨
+펩
+펫
+펭
+펴
+편
+펼
+폄
+폈
+평
+폐
+포
+폭
+폰
+폴
+폼
+폿
+퐁
+표
+푭
+푸
+푹
+푼
+풀
+품
+풋
+풍
+퓨
+퓬
+퓰
+퓸
+프
+픈
+플
+픔
+픕
+피
+픽
+핀
+필
+핌
+핍
+핏
+핑
+하
+학
+한
+할
+핥
+함
+합
+핫
+항
+해
+핵
+핸
+핼
+햄
+햅
+햇
+했
+행
+햐
+향
+헀
+허
+헉
+헌
+헐
+험
+헙
+헛
+헝
+헤
+헥
+헨
+헬
+헴
+헵
+헷
+헹
+혀
+혁
+현
+혈
+혐
+협
+혓
+혔
+형
+혜
+호
+혹
+혼
+홀
+홈
+홉
+홋
+홍
+홑
+화
+확
+환
+활
+홧
+황
+홰
+홱
+횃
+회
+획
+횝
+횟
+횡
+효
+후
+훅
+훈
+훌
+훑
+훔
+훗
+훤
+훨
+훼
+휄
+휑
+휘
+휙
+휜
+휠
+휩
+휭
+휴
+휼
+흄
+흉
+흐
+흑
+흔
+흘
+흙
+흠
+흡
+흣
+흥
+흩
+희
+흰
+흽
+히
+힉
+힌
+힐
+힘
+힙
+힝
+車
+滑
+金
+奈
+羅
+洛
+卵
+欄
+蘭
+郎
+來
+盧
+老
+魯
+綠
+鹿
+論
+雷
+樓
+縷
+凌
+樂
+不
+參
+葉
+沈
+若
+兩
+凉
+梁
+呂
+女
+廬
+麗
+黎
+曆
+歷
+戀
+蓮
+連
+列
+烈
+裂
+念
+獵
+靈
+領
+例
+禮
+醴
+惡
+尿
+料
+遼
+龍
+暈
+柳
+流
+類
+六
+陸
+倫
+律
+栗
+利
+李
+梨
+理
+離
+燐
+林
+臨
+立
+茶
+切
+宅
+
diff --git a/tools/utils/dict/latex_symbol_dict.txt b/tools/utils/dict/latex_symbol_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d17f2c2cf79cf50fa83c17e087bb77a79a862b39
--- /dev/null
+++ b/tools/utils/dict/latex_symbol_dict.txt
@@ -0,0 +1,111 @@
+eos
+sos
+!
+'
+(
+)
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+<
+=
+>
+A
+B
+C
+E
+F
+G
+H
+I
+L
+M
+N
+P
+R
+S
+T
+V
+X
+Y
+[
+\Delta
+\alpha
+\beta
+\cdot
+\cdots
+\cos
+\div
+\exists
+\forall
+\frac
+\gamma
+\geq
+\in
+\infty
+\int
+\lambda
+\ldots
+\leq
+\lim
+\log
+\mu
+\neq
+\phi
+\pi
+\pm
+\prime
+\rightarrow
+\sigma
+\sin
+\sqrt
+\sum
+\tan
+\theta
+\times
+]
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+\{
+|
+\}
+{
+}
+^
+_
\ No newline at end of file
diff --git a/tools/utils/dict/latin_dict.txt b/tools/utils/dict/latin_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e166bf33ecfbdc90ddb3d9743fded23306acabd5
--- /dev/null
+++ b/tools/utils/dict/latin_dict.txt
@@ -0,0 +1,185 @@
+
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+]
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+}
+¡
+£
+§
+ª
+«
+
+°
+²
+³
+´
+µ
+·
+º
+»
+¿
+À
+Á
+Â
+Ä
+Å
+Ç
+È
+É
+Ê
+Ë
+Ì
+Í
+Î
+Ï
+Ò
+Ó
+Ô
+Õ
+Ö
+Ú
+Ü
+Ý
+ß
+à
+á
+â
+ã
+ä
+å
+æ
+ç
+è
+é
+ê
+ë
+ì
+í
+î
+ï
+ñ
+ò
+ó
+ô
+õ
+ö
+ø
+ù
+ú
+û
+ü
+ý
+ą
+Ć
+ć
+Č
+č
+Đ
+đ
+ę
+ı
+Ł
+ł
+ō
+Œ
+œ
+Š
+š
+Ÿ
+Ž
+ž
+ʒ
+β
+δ
+ε
+з
+Ṡ
+‘
+€
+™
diff --git a/tools/utils/dict/layout_dict/layout_cdla_dict.txt b/tools/utils/dict/layout_dict/layout_cdla_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8be0f48600a88463d840fffe27eebd63629576ce
--- /dev/null
+++ b/tools/utils/dict/layout_dict/layout_cdla_dict.txt
@@ -0,0 +1,10 @@
+text
+title
+figure
+figure_caption
+table
+table_caption
+header
+footer
+reference
+equation
\ No newline at end of file
diff --git a/tools/utils/dict/layout_dict/layout_publaynet_dict.txt b/tools/utils/dict/layout_dict/layout_publaynet_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ca6acf4eef8d4d5f9ba5a4ced4858a119a4ef983
--- /dev/null
+++ b/tools/utils/dict/layout_dict/layout_publaynet_dict.txt
@@ -0,0 +1,5 @@
+text
+title
+list
+table
+figure
\ No newline at end of file
diff --git a/tools/utils/dict/layout_dict/layout_table_dict.txt b/tools/utils/dict/layout_dict/layout_table_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..faea15ea07d7d1a6f77dbd4287bb9fa87165cbb9
--- /dev/null
+++ b/tools/utils/dict/layout_dict/layout_table_dict.txt
@@ -0,0 +1 @@
+table
\ No newline at end of file
diff --git a/tools/utils/dict/mr_dict.txt b/tools/utils/dict/mr_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..283b1504ae344ed7db95050ddb9e3682126cc741
--- /dev/null
+++ b/tools/utils/dict/mr_dict.txt
@@ -0,0 +1,153 @@
+
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+ँ
+ं
+ः
+अ
+आ
+इ
+ई
+उ
+ऊ
+ए
+ऐ
+ऑ
+ओ
+औ
+क
+ख
+ग
+घ
+च
+छ
+ज
+झ
+ञ
+ट
+ठ
+ड
+ढ
+ण
+त
+थ
+द
+ध
+न
+प
+फ
+ब
+भ
+म
+य
+र
+ऱ
+ल
+ळ
+व
+श
+ष
+स
+ह
+़
+ा
+ि
+ी
+ु
+ू
+ृ
+ॅ
+े
+ै
+ॉ
+ो
+ौ
+्
+०
+१
+२
+३
+४
+५
+६
+७
+८
+९
diff --git a/tools/utils/dict/ne_dict.txt b/tools/utils/dict/ne_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5a7df9537fc886c4de53ee68ab67b171b386780f
--- /dev/null
+++ b/tools/utils/dict/ne_dict.txt
@@ -0,0 +1,153 @@
+
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+ः
+अ
+आ
+इ
+ई
+उ
+ऊ
+ऋ
+ए
+ऐ
+ओ
+औ
+क
+ख
+ग
+घ
+ङ
+च
+छ
+ज
+झ
+ञ
+ट
+ठ
+ड
+ढ
+ण
+त
+थ
+द
+ध
+न
+ऩ
+प
+फ
+ब
+भ
+म
+य
+र
+ऱ
+ल
+व
+श
+ष
+स
+ह
+़
+ा
+ि
+ी
+ु
+ू
+ृ
+े
+ै
+ो
+ौ
+्
+॒
+ॠ
+।
+०
+१
+२
+३
+४
+५
+६
+७
+८
+९
diff --git a/tools/utils/dict/oc_dict.txt b/tools/utils/dict/oc_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e88af8bd85f4e01c43d08e5ba3ae6cadc5a465a4
--- /dev/null
+++ b/tools/utils/dict/oc_dict.txt
@@ -0,0 +1,96 @@
+o
+c
+_
+i
+m
+g
+/
+2
+0
+I
+L
+S
+V
+R
+C
+1
+v
+a
+l
+4
+3
+.
+j
+p
+r
+e
+è
+t
+9
+7
+5
+8
+n
+'
+b
+s
+6
+q
+u
+á
+d
+ò
+à
+h
+z
+f
+ï
+í
+A
+ç
+x
+ó
+é
+P
+O
+Ò
+ü
+k
+À
+F
+-
+ú
+
+æ
+Á
+D
+E
+w
+K
+T
+N
+y
+U
+Z
+G
+B
+J
+H
+M
+W
+Y
+X
+Q
+%
+$
+,
+@
+&
+!
+:
+(
+#
+?
++
+É
+
diff --git a/tools/utils/dict/pu_dict.txt b/tools/utils/dict/pu_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9500fae6e4976ea632bf579b533f82f176f3b7e7
--- /dev/null
+++ b/tools/utils/dict/pu_dict.txt
@@ -0,0 +1,130 @@
+p
+u
+_
+i
+m
+g
+/
+8
+I
+L
+S
+V
+R
+C
+2
+0
+1
+v
+a
+l
+6
+7
+4
+5
+.
+j
+
+q
+e
+s
+t
+ã
+o
+x
+9
+c
+n
+r
+z
+ç
+õ
+3
+A
+U
+d
+º
+ô
+
+,
+E
+;
+ó
+á
+b
+D
+?
+ú
+ê
+-
+h
+P
+f
+à
+N
+í
+O
+M
+G
+É
+é
+â
+F
+:
+T
+Á
+"
+Q
+)
+W
+J
+B
+H
+(
+ö
+%
+Ö
+«
+w
+K
+y
+!
+k
+]
+'
+Z
++
+Ç
+Õ
+Y
+À
+X
+µ
+»
+ª
+Í
+ü
+ä
+´
+è
+ñ
+ß
+ï
+Ú
+ë
+Ô
+Ï
+Ó
+[
+Ì
+<
+Â
+ò
+§
+³
+ø
+å
+#
+$
+&
+@
diff --git a/tools/utils/dict/rs_dict.txt b/tools/utils/dict/rs_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d1ce46d240841b8471cbb1209ee92864895c667c
--- /dev/null
+++ b/tools/utils/dict/rs_dict.txt
@@ -0,0 +1,91 @@
+r
+s
+_
+i
+m
+g
+/
+1
+I
+L
+S
+V
+R
+C
+2
+0
+v
+a
+l
+7
+5
+8
+6
+.
+j
+p
+
+t
+d
+9
+3
+e
+š
+4
+k
+u
+ć
+c
+n
+đ
+o
+z
+č
+b
+ž
+f
+Z
+T
+h
+M
+F
+O
+Š
+B
+H
+A
+E
+Đ
+Ž
+D
+P
+G
+Č
+K
+U
+N
+J
+Ć
+w
+y
+W
+x
+Y
+X
+q
+Q
+#
+&
+$
+,
+-
+%
+'
+@
+!
+:
+?
+(
+É
+é
++
diff --git a/tools/utils/dict/rsc_dict.txt b/tools/utils/dict/rsc_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..95dd4636057e5b6dd8bd3a3dd6aacf19e790cffb
--- /dev/null
+++ b/tools/utils/dict/rsc_dict.txt
@@ -0,0 +1,134 @@
+r
+s
+c
+_
+i
+m
+g
+/
+5
+I
+L
+S
+V
+R
+C
+2
+0
+1
+v
+a
+l
+9
+7
+8
+.
+j
+p
+м
+а
+с
+и
+р
+ћ
+е
+ш
+3
+4
+о
+г
+н
+з
+в
+л
+6
+т
+ж
+у
+к
+п
+њ
+д
+ч
+С
+ј
+ф
+ц
+љ
+х
+О
+И
+А
+б
+Ш
+К
+ђ
+џ
+М
+В
+З
+Д
+Р
+У
+Н
+Т
+Б
+?
+П
+Х
+Ј
+Ц
+Г
+Љ
+Л
+Ф
+e
+n
+w
+E
+F
+A
+N
+f
+o
+b
+M
+G
+t
+y
+W
+k
+P
+u
+H
+B
+T
+z
+h
+O
+Y
+d
+U
+K
+D
+x
+X
+J
+Z
+Q
+q
+'
+-
+@
+é
+#
+!
+,
+%
+$
+:
+&
++
+(
+É
+
diff --git a/tools/utils/dict/ru_dict.txt b/tools/utils/dict/ru_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3b0cf3a8d6cd61ae395d1242dae3d42906029e2c
--- /dev/null
+++ b/tools/utils/dict/ru_dict.txt
@@ -0,0 +1,125 @@
+к
+в
+а
+з
+и
+у
+р
+о
+н
+я
+х
+п
+л
+ы
+г
+е
+т
+м
+д
+ж
+ш
+ь
+с
+ё
+б
+й
+ч
+ю
+ц
+щ
+М
+э
+ф
+А
+ъ
+С
+Ф
+Ю
+В
+К
+Т
+Н
+О
+Э
+У
+И
+Г
+Л
+Р
+Д
+Б
+Ш
+П
+З
+Х
+Е
+Ж
+Я
+Ц
+Ч
+Й
+Щ
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+
diff --git a/tools/utils/dict/spin_dict.txt b/tools/utils/dict/spin_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8ee8347fd9c85228a3cf46c810d4fc28ab05c492
--- /dev/null
+++ b/tools/utils/dict/spin_dict.txt
@@ -0,0 +1,68 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+:
+(
+'
+-
+,
+%
+>
+.
+[
+?
+)
+"
+=
+_
+*
+]
+;
+&
++
+$
+@
+/
+|
+!
+<
+#
+`
+{
+~
+\
+}
+^
\ No newline at end of file
diff --git a/tools/utils/dict/ta_dict.txt b/tools/utils/dict/ta_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..19d81892c205627f296adbf8b20ea41aba2de5d0
--- /dev/null
+++ b/tools/utils/dict/ta_dict.txt
@@ -0,0 +1,128 @@
+t
+a
+_
+i
+m
+g
+/
+3
+I
+L
+S
+V
+R
+C
+2
+0
+1
+v
+l
+9
+7
+8
+.
+j
+p
+ப
+ூ
+த
+ம
+ி
+வ
+ர
+்
+ந
+ோ
+ன
+6
+ஆ
+ற
+ல
+5
+ள
+ா
+ொ
+ழ
+ு
+4
+ெ
+ண
+க
+ட
+ை
+ே
+ச
+ய
+ஒ
+இ
+அ
+ங
+உ
+ீ
+ஞ
+எ
+ஓ
+ஃ
+ஜ
+ஷ
+ஸ
+ஏ
+ஊ
+ஹ
+ஈ
+ஐ
+ௌ
+ஔ
+s
+c
+e
+n
+w
+F
+T
+O
+P
+K
+A
+N
+G
+Y
+E
+M
+H
+U
+B
+o
+b
+D
+d
+r
+W
+u
+y
+f
+X
+k
+q
+h
+J
+z
+Z
+Q
+x
+-
+'
+$
+,
+%
+@
+é
+!
+#
++
+É
+&
+:
+(
+?
+
diff --git a/tools/utils/dict/table_dict.txt b/tools/utils/dict/table_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2ef028c786cbce6d1e25856c62986d757b31f93b
--- /dev/null
+++ b/tools/utils/dict/table_dict.txt
@@ -0,0 +1,277 @@
+←
+
+☆
+─
+α
+
+
+⋅
+$
+ω
+ψ
+χ
+(
+υ
+≥
+σ
+,
+ρ
+ε
+0
+■
+4
+8
+✗
+b
+<
+✓
+Ψ
+Ω
+€
+D
+3
+Π
+H
+║
+
+L
+Φ
+Χ
+θ
+P
+κ
+λ
+μ
+T
+ξ
+X
+β
+γ
+δ
+\
+ζ
+η
+`
+d
+
+h
+f
+l
+Θ
+p
+√
+t
+
+x
+Β
+Γ
+Δ
+|
+ǂ
+ɛ
+j
+̧
+➢
+
+̌
+′
+«
+△
+▲
+#
+
+'
+Ι
++
+¶
+/
+▼
+⇑
+□
+·
+7
+▪
+;
+?
+➔
+∩
+C
+÷
+G
+⇒
+K
+
+O
+S
+С
+W
+Α
+[
+○
+_
+●
+‡
+c
+z
+g
+
+o
+
+〈
+〉
+s
+⩽
+w
+φ
+ʹ
+{
+»
+∣
+̆
+e
+ˆ
+∈
+τ
+◆
+ι
+∅
+∆
+∙
+∘
+Ø
+ß
+✔
+∞
+∑
+−
+×
+◊
+∗
+∖
+˃
+˂
+∫
+"
+i
+&
+π
+↔
+*
+∥
+æ
+∧
+.
+⁄
+ø
+Q
+∼
+6
+⁎
+:
+★
+>
+a
+B
+≈
+F
+J
+̄
+N
+♯
+R
+V
+
+―
+Z
+♣
+^
+¤
+¥
+§
+
+¢
+£
+≦
+
+≤
+‖
+Λ
+©
+n
+↓
+→
+↑
+r
+°
+±
+v
+
+♂
+k
+♀
+~
+ᅟ
+̇
+@
+”
+♦
+ł
+®
+⊕
+„
+!
+
+%
+⇓
+)
+-
+1
+5
+9
+=
+А
+A
+‰
+⋆
+Σ
+E
+◦
+I
+※
+M
+m
+̨
+⩾
+†
+
+•
+U
+Y
+
+]
+̸
+2
+‐
+–
+‒
+̂
+—
+̀
+́
+’
+‘
+⋮
+⋯
+̊
+“
+̈
+≧
+q
+u
+ı
+y
+
+
+̃
+}
+ν
diff --git a/tools/utils/dict/table_master_structure_dict.txt b/tools/utils/dict/table_master_structure_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..95ab2539a70aca4f695c53a38cdc1c3e164fcfb3
--- /dev/null
+++ b/tools/utils/dict/table_master_structure_dict.txt
@@ -0,0 +1,39 @@
+
+
+ |
+
+
+
+
+
+
+ |
+ colspan="2"
+ colspan="3"
+
+
+ rowspan="2"
+ colspan="4"
+ colspan="6"
+ rowspan="3"
+ colspan="9"
+ colspan="10"
+ colspan="7"
+ rowspan="4"
+ rowspan="5"
+ rowspan="9"
+ colspan="8"
+ rowspan="8"
+ rowspan="6"
+ rowspan="7"
+ rowspan="10"
+
+
+
+
+
+
+
+
diff --git a/tools/utils/dict/table_structure_dict.txt b/tools/utils/dict/table_structure_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8edb10b8817ad596af6c63b6b8fc5eb2349b7464
--- /dev/null
+++ b/tools/utils/dict/table_structure_dict.txt
@@ -0,0 +1,28 @@
+
+
+
+ |
+
+
+
+
+
+ colspan="2"
+ colspan="3"
+ rowspan="2"
+ colspan="4"
+ colspan="6"
+ rowspan="3"
+ colspan="9"
+ colspan="10"
+ colspan="7"
+ rowspan="4"
+ rowspan="5"
+ rowspan="9"
+ colspan="8"
+ rowspan="8"
+ rowspan="6"
+ rowspan="7"
+ rowspan="10"
\ No newline at end of file
diff --git a/tools/utils/dict/table_structure_dict_ch.txt b/tools/utils/dict/table_structure_dict_ch.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0c59c0e9998a31f9d32f703625aa1c5ca7718c8d
--- /dev/null
+++ b/tools/utils/dict/table_structure_dict_ch.txt
@@ -0,0 +1,48 @@
+
+
+ |
+
+
+
+
+ |
+ |
+ colspan="2"
+ colspan="3"
+ colspan="4"
+ colspan="5"
+ colspan="6"
+ colspan="7"
+ colspan="8"
+ colspan="9"
+ colspan="10"
+ colspan="11"
+ colspan="12"
+ colspan="13"
+ colspan="14"
+ colspan="15"
+ colspan="16"
+ colspan="17"
+ colspan="18"
+ colspan="19"
+ colspan="20"
+ rowspan="2"
+ rowspan="3"
+ rowspan="4"
+ rowspan="5"
+ rowspan="6"
+ rowspan="7"
+ rowspan="8"
+ rowspan="9"
+ rowspan="10"
+ rowspan="11"
+ rowspan="12"
+ rowspan="13"
+ rowspan="14"
+ rowspan="15"
+ rowspan="16"
+ rowspan="17"
+ rowspan="18"
+ rowspan="19"
+ rowspan="20"
diff --git a/tools/utils/dict/te_dict.txt b/tools/utils/dict/te_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..83d74cc7e5f899ca43b23fa690d84d70bee535e3
--- /dev/null
+++ b/tools/utils/dict/te_dict.txt
@@ -0,0 +1,151 @@
+t
+e
+_
+i
+m
+g
+/
+5
+I
+L
+S
+V
+R
+C
+2
+0
+1
+v
+a
+l
+3
+4
+8
+9
+.
+j
+p
+త
+ె
+ర
+క
+్
+ి
+ం
+చ
+ే
+ద
+ు
+7
+6
+ఉ
+ా
+మ
+ట
+ో
+వ
+ప
+ల
+శ
+ఆ
+య
+ై
+భ
+'
+ీ
+గ
+ూ
+డ
+ధ
+హ
+న
+జ
+స
+[
+
+ష
+అ
+ణ
+ఫ
+బ
+ఎ
+;
+ళ
+థ
+ొ
+ఠ
+ృ
+ఒ
+ఇ
+ః
+ఊ
+ఖ
+-
+ఐ
+ఘ
+ౌ
+ఏ
+ఈ
+ఛ
+,
+ఓ
+ఞ
+|
+?
+:
+ఢ
+"
+(
+”
+!
++
+)
+*
+=
+&
+“
+€
+]
+£
+$
+s
+c
+n
+w
+k
+J
+G
+u
+d
+r
+E
+o
+h
+y
+b
+f
+B
+M
+O
+T
+N
+D
+P
+A
+F
+x
+W
+Y
+U
+H
+K
+X
+z
+Z
+Q
+q
+É
+%
+#
+@
+é
diff --git a/tools/utils/dict/ug_dict.txt b/tools/utils/dict/ug_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..77602f2cfd29d739478bc9e757bd82b71235554b
--- /dev/null
+++ b/tools/utils/dict/ug_dict.txt
@@ -0,0 +1,114 @@
+u
+g
+_
+i
+m
+/
+1
+I
+L
+S
+V
+R
+C
+2
+0
+v
+a
+l
+8
+5
+3
+6
+9
+.
+j
+p
+
+ق
+ا
+پ
+ل
+4
+7
+ئ
+ى
+ش
+ت
+ي
+ك
+د
+ف
+ر
+و
+ن
+ب
+ە
+خ
+ې
+چ
+ۇ
+ز
+س
+م
+ۋ
+گ
+ڭ
+ۆ
+ۈ
+ج
+غ
+ھ
+ژ
+s
+c
+e
+n
+w
+P
+E
+D
+U
+d
+r
+b
+y
+B
+o
+O
+Y
+N
+T
+k
+t
+h
+A
+H
+F
+z
+W
+K
+G
+M
+f
+Z
+X
+Q
+J
+x
+q
+-
+!
+%
+#
+?
+:
+$
+,
+&
+'
+É
+@
+é
+(
++
diff --git a/tools/utils/dict/uk_dict.txt b/tools/utils/dict/uk_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c5ffc0a53dbdf7af6d911097dffb8733d7d4eab1
--- /dev/null
+++ b/tools/utils/dict/uk_dict.txt
@@ -0,0 +1,142 @@
+u
+k
+_
+i
+m
+g
+/
+1
+6
+I
+L
+S
+V
+R
+C
+2
+0
+v
+a
+l
+7
+9
+.
+j
+p
+в
+і
+д
+п
+о
+н
+с
+т
+ю
+4
+5
+3
+а
+и
+м
+е
+р
+ч
+у
+Б
+з
+л
+к
+8
+А
+В
+г
+є
+б
+ь
+х
+ґ
+ш
+ц
+ф
+я
+щ
+ж
+Г
+Х
+У
+Т
+Е
+І
+Н
+П
+З
+Л
+Ю
+С
+Д
+М
+К
+Р
+Ф
+О
+Ц
+И
+Я
+Ч
+Ш
+Ж
+Є
+Ґ
+Ь
+s
+c
+e
+n
+w
+A
+P
+r
+E
+t
+o
+h
+d
+y
+M
+G
+N
+F
+B
+T
+D
+U
+O
+W
+Z
+f
+H
+Y
+b
+K
+z
+x
+Q
+X
+q
+J
+$
+-
+'
+#
+&
+%
+?
+:
+!
+,
++
+@
+(
+é
+É
+
diff --git a/tools/utils/dict/ur_dict.txt b/tools/utils/dict/ur_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c06786a83bc60039fa395d71e367bece1e80b11d
--- /dev/null
+++ b/tools/utils/dict/ur_dict.txt
@@ -0,0 +1,137 @@
+u
+r
+_
+i
+m
+g
+/
+3
+I
+L
+S
+V
+R
+C
+2
+0
+1
+v
+a
+l
+9
+7
+8
+.
+j
+p
+
+چ
+ٹ
+پ
+ا
+ئ
+ی
+ے
+4
+6
+و
+ل
+ن
+ڈ
+ھ
+ک
+ت
+ش
+ف
+ق
+ر
+د
+5
+ب
+ج
+خ
+ہ
+س
+ز
+غ
+ڑ
+ں
+آ
+م
+ؤ
+ط
+ص
+ح
+ع
+گ
+ث
+ض
+ذ
+ۓ
+ِ
+ء
+ظ
+ً
+ي
+ُ
+ۃ
+أ
+ٰ
+ە
+ژ
+ۂ
+ة
+ّ
+ك
+ه
+s
+c
+e
+n
+w
+o
+d
+t
+D
+M
+T
+U
+E
+b
+P
+h
+y
+W
+H
+A
+x
+B
+O
+N
+G
+Y
+Q
+F
+k
+K
+q
+J
+Z
+f
+z
+X
+'
+@
+&
+!
+,
+:
+$
+-
+#
+?
+%
+é
++
+(
+É
diff --git a/tools/utils/dict/xi_dict.txt b/tools/utils/dict/xi_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f195f1ea6ce90b09f5feecfc6c288eed423eeb8d
--- /dev/null
+++ b/tools/utils/dict/xi_dict.txt
@@ -0,0 +1,110 @@
+x
+i
+_
+m
+g
+/
+1
+0
+I
+L
+S
+V
+R
+C
+2
+v
+a
+l
+3
+6
+4
+5
+.
+j
+p
+
+Q
+u
+e
+r
+o
+8
+7
+n
+c
+9
+t
+b
+é
+q
+d
+ó
+y
+F
+s
+,
+O
+í
+T
+f
+"
+U
+M
+h
+:
+P
+H
+A
+E
+D
+z
+N
+á
+ñ
+ú
+%
+;
+è
++
+Y
+-
+B
+G
+(
+)
+¿
+?
+w
+¡
+!
+X
+É
+K
+k
+Á
+ü
+Ú
+«
+»
+J
+'
+ö
+W
+Z
+º
+Ö
+
+[
+]
+Ç
+ç
+à
+ä
+û
+ò
+Í
+ê
+ô
+ø
+ª
diff --git a/tools/utils/dict90.txt b/tools/utils/dict90.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a945ae9c526e4faa68852eb3fb47d078a2f3f6ce
--- /dev/null
+++ b/tools/utils/dict90.txt
@@ -0,0 +1,90 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+:
+;
+<
+=
+>
+?
+@
+[
+\
+]
+_
+`
+~
\ No newline at end of file
diff --git a/tools/utils/e2e_metric/Deteval.py b/tools/utils/e2e_metric/Deteval.py
new file mode 100644
index 0000000000000000000000000000000000000000..f53588c719ea68f93cd7ff3152696e8e557979e9
--- /dev/null
+++ b/tools/utils/e2e_metric/Deteval.py
@@ -0,0 +1,802 @@
+import json
+import numpy as np
+import scipy.io as io
+
+from tools.utils.utility import check_install
+
+from tools.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
+
+
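+# get_socre_A / get_socre_B build, for a single image, the DetEval "sigma" and
+# "tau" overlap tables (sigma = intersection / gt_area, tau = intersection /
+# det_area) together with the predicted and ground-truth transcriptions;
+# combine_results later matches detections to ground truths from these tables.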
+def get_socre_A(gt_dir, pred_dict):
+ allInputs = 1
+
+ def input_reading_mod(pred_dict):
+ """This helper reads input from txt files"""
+ det = []
+ n = len(pred_dict)
+ for i in range(n):
+ points = pred_dict[i]["points"]
+ text = pred_dict[i]["texts"]
+ point = ",".join(map(
+ str,
+ points.reshape(-1, ), ))
+ det.append([point, text])
+ return det
+
+ def gt_reading_mod(gt_dict):
+ """This helper reads groundtruths from mat files"""
+ gt = []
+ n = len(gt_dict)
+ for i in range(n):
+ points = gt_dict[i]["points"].tolist()
+ h = len(points)
+ text = gt_dict[i]["text"]
+            # gt entry layout: ["x:", x-coords, "y:", y-coords, transcription, care-flag]
+            xx = [
+                np.array(["x:"], dtype="<U2"),
+                0,
+                np.array(["y:"], dtype="<U2"),
+                0,
+                np.array(["#"], dtype="<U1"),
+                np.array(["#"], dtype="<U1"),
+            ]
+            t_x, t_y = [], []
+            for j in range(h):
+                t_x.append(points[j][0])
+                t_y.append(points[j][1])
+            xx[1] = np.array([t_x], dtype="int16")
+            xx[3] = np.array([t_y], dtype="int16")
+            if text != "" and "#" not in text:
+                xx[4] = np.array([text], dtype="U{}".format(len(text)))
+                xx[5] = np.array(["c"], dtype="<U1")
+            gt.append(xx)
+        return gt
+
+    def detection_filtering(detections, groundtruths, threshold=0.5):
+        for gt_id, gt in enumerate(groundtruths):
+            if (gt[5] == "#") and (gt[1].shape[1] > 1):
+ gt_x = list(map(int, np.squeeze(gt[1])))
+ gt_y = list(map(int, np.squeeze(gt[3])))
+ for det_id, detection in enumerate(detections):
+ detection_orig = detection
+ detection = [float(x) for x in detection[0].split(",")]
+ detection = list(map(int, detection))
+ det_x = detection[0::2]
+ det_y = detection[1::2]
+ det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
+ if det_gt_iou > threshold:
+ detections[det_id] = []
+
+ detections[:] = [item for item in detections if item != []]
+ return detections
+
+ def sigma_calculation(det_x, det_y, gt_x, gt_y):
+ """
+ sigma = inter_area / gt_area
+ """
+ return np.round(
+ (area_of_intersection(det_x, det_y, gt_x, gt_y) / area(gt_x, gt_y)),
+ 2)
+
+ def tau_calculation(det_x, det_y, gt_x, gt_y):
+ if area(det_x, det_y) == 0.0:
+ return 0
+ return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
+ area(det_x, det_y)), 2)
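+    # Worked example: a detection that covers 80 of a ground truth's 100 area units
+    # while spanning 160 units itself gets sigma = 80 / 100 = 0.8 (how much of the
+    # ground truth is recovered) and tau = 80 / 160 = 0.5 (how precise the detection is).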
+
+ ##############################Initialization###################################
+ # global_sigma = []
+ # global_tau = []
+ # global_pred_str = []
+ # global_gt_str = []
+ ###############################################################################
+
+ for input_id in range(allInputs):
+ if ((input_id != ".DS_Store") and (input_id != "Pascal_result.txt") and
+ (input_id != "Pascal_result_curved.txt") and
+ (input_id != "Pascal_result_non_curved.txt") and
+ (input_id != "Deteval_result.txt") and
+ (input_id != "Deteval_result_curved.txt") and
+ (input_id != "Deteval_result_non_curved.txt")):
+ detections = input_reading_mod(pred_dict)
+ groundtruths = gt_reading_mod(gt_dir)
+ detections = detection_filtering(
+ detections,
+ groundtruths) # filters detections overlapping with DC area
+ dc_id = []
+ for i in range(len(groundtruths)):
+ if groundtruths[i][5] == "#":
+ dc_id.append(i)
+ cnt = 0
+ for a in dc_id:
+ num = a - cnt
+ del groundtruths[num]
+ cnt += 1
+
+ local_sigma_table = np.zeros((len(groundtruths), len(detections)))
+ local_tau_table = np.zeros((len(groundtruths), len(detections)))
+ local_pred_str = {}
+ local_gt_str = {}
+
+ for gt_id, gt in enumerate(groundtruths):
+ if len(detections) > 0:
+ for det_id, detection in enumerate(detections):
+ detection_orig = detection
+ detection = [float(x) for x in detection[0].split(",")]
+ detection = list(map(int, detection))
+ pred_seq_str = detection_orig[1].strip()
+ det_x = detection[0::2]
+ det_y = detection[1::2]
+ gt_x = list(map(int, np.squeeze(gt[1])))
+ gt_y = list(map(int, np.squeeze(gt[3])))
+ gt_seq_str = str(gt[4].tolist()[0])
+
+ local_sigma_table[gt_id, det_id] = sigma_calculation(
+ det_x, det_y, gt_x, gt_y)
+ local_tau_table[gt_id, det_id] = tau_calculation(
+ det_x, det_y, gt_x, gt_y)
+ local_pred_str[det_id] = pred_seq_str
+ local_gt_str[gt_id] = gt_seq_str
+
+ global_sigma = local_sigma_table
+ global_tau = local_tau_table
+ global_pred_str = local_pred_str
+ global_gt_str = local_gt_str
+
+ single_data = {}
+ single_data["sigma"] = global_sigma
+ single_data["global_tau"] = global_tau
+ single_data["global_pred_str"] = global_pred_str
+ single_data["global_gt_str"] = global_gt_str
+ return single_data
+
+
+def get_socre_B(gt_dir, img_id, pred_dict):
+ allInputs = 1
+
+ def input_reading_mod(pred_dict):
+ """This helper reads input from txt files"""
+ det = []
+ n = len(pred_dict)
+ for i in range(n):
+ points = pred_dict[i]["points"]
+ text = pred_dict[i]["texts"]
+ point = ",".join(map(
+ str,
+ points.reshape(-1, ), ))
+ det.append([point, text])
+ return det
+
+ def gt_reading_mod(gt_dir, gt_id):
+ gt = io.loadmat("%s/poly_gt_img%s.mat" % (gt_dir, gt_id))
+ gt = gt["polygt"]
+ return gt
+
+ def detection_filtering(detections, groundtruths, threshold=0.5):
+ for gt_id, gt in enumerate(groundtruths):
+ if (gt[5] == "#") and (gt[1].shape[1] > 1):
+ gt_x = list(map(int, np.squeeze(gt[1])))
+ gt_y = list(map(int, np.squeeze(gt[3])))
+ for det_id, detection in enumerate(detections):
+ detection_orig = detection
+ detection = [float(x) for x in detection[0].split(",")]
+ detection = list(map(int, detection))
+ det_x = detection[0::2]
+ det_y = detection[1::2]
+ det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
+ if det_gt_iou > threshold:
+ detections[det_id] = []
+
+ detections[:] = [item for item in detections if item != []]
+ return detections
+
+ def sigma_calculation(det_x, det_y, gt_x, gt_y):
+ """
+ sigma = inter_area / gt_area
+ """
+ return np.round(
+ (area_of_intersection(det_x, det_y, gt_x, gt_y) / area(gt_x, gt_y)),
+ 2)
+
+ def tau_calculation(det_x, det_y, gt_x, gt_y):
+ if area(det_x, det_y) == 0.0:
+ return 0
+ return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
+ area(det_x, det_y)), 2)
+
+ ##############################Initialization###################################
+ # global_sigma = []
+ # global_tau = []
+ # global_pred_str = []
+ # global_gt_str = []
+ ###############################################################################
+
+ for input_id in range(allInputs):
+ if ((input_id != ".DS_Store") and (input_id != "Pascal_result.txt") and
+ (input_id != "Pascal_result_curved.txt") and
+ (input_id != "Pascal_result_non_curved.txt") and
+ (input_id != "Deteval_result.txt") and
+ (input_id != "Deteval_result_curved.txt") and
+ (input_id != "Deteval_result_non_curved.txt")):
+ detections = input_reading_mod(pred_dict)
+ groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
+ detections = detection_filtering(
+ detections,
+ groundtruths) # filters detections overlapping with DC area
+ dc_id = []
+ for i in range(len(groundtruths)):
+ if groundtruths[i][5] == "#":
+ dc_id.append(i)
+ cnt = 0
+ for a in dc_id:
+ num = a - cnt
+ del groundtruths[num]
+ cnt += 1
+
+ local_sigma_table = np.zeros((len(groundtruths), len(detections)))
+ local_tau_table = np.zeros((len(groundtruths), len(detections)))
+ local_pred_str = {}
+ local_gt_str = {}
+
+ for gt_id, gt in enumerate(groundtruths):
+ if len(detections) > 0:
+ for det_id, detection in enumerate(detections):
+ detection_orig = detection
+ detection = [float(x) for x in detection[0].split(",")]
+ detection = list(map(int, detection))
+ pred_seq_str = detection_orig[1].strip()
+ det_x = detection[0::2]
+ det_y = detection[1::2]
+ gt_x = list(map(int, np.squeeze(gt[1])))
+ gt_y = list(map(int, np.squeeze(gt[3])))
+ gt_seq_str = str(gt[4].tolist()[0])
+
+ local_sigma_table[gt_id, det_id] = sigma_calculation(
+ det_x, det_y, gt_x, gt_y)
+ local_tau_table[gt_id, det_id] = tau_calculation(
+ det_x, det_y, gt_x, gt_y)
+ local_pred_str[det_id] = pred_seq_str
+ local_gt_str[gt_id] = gt_seq_str
+
+ global_sigma = local_sigma_table
+ global_tau = local_tau_table
+ global_pred_str = local_pred_str
+ global_gt_str = local_gt_str
+
+ single_data = {}
+ single_data["sigma"] = global_sigma
+ single_data["global_tau"] = global_tau
+ single_data["global_pred_str"] = global_pred_str
+ single_data["global_gt_str"] = global_gt_str
+ return single_data
+
+
+def get_score_C(gt_label, text, pred_bboxes):
+ """
+ get score for CentripetalText (CT) prediction.
+ """
+ check_install("Polygon", "Polygon3")
+ import Polygon as plg
+
+ def gt_reading_mod(gt_label, text):
+ """This helper reads groundtruths from mat files"""
+ groundtruths = []
+ nbox = len(gt_label)
+ for i in range(nbox):
+ label = {"transcription": text[i][0], "points": gt_label[i].numpy()}
+ groundtruths.append(label)
+
+ return groundtruths
+
+ def get_union(pD, pG):
+ areaA = pD.area()
+ areaB = pG.area()
+ return areaA + areaB - get_intersection(pD, pG)
+
+ def get_intersection(pD, pG):
+ pInt = pD & pG
+ if len(pInt) == 0:
+ return 0
+ return pInt.area()
+
+ def detection_filtering(detections, groundtruths, threshold=0.5):
+ for gt in groundtruths:
+ point_num = gt["points"].shape[1] // 2
+ if gt["transcription"] == "###" and (point_num > 1):
+ gt_p = np.array(gt["points"]).reshape(point_num,
+ 2).astype("int32")
+ gt_p = plg.Polygon(gt_p)
+
+ for det_id, detection in enumerate(detections):
+ det_y = detection[0::2]
+ det_x = detection[1::2]
+
+ det_p = np.concatenate((np.array(det_x), np.array(det_y)))
+ det_p = det_p.reshape(2, -1).transpose()
+ det_p = plg.Polygon(det_p)
+
+ try:
+ det_gt_iou = get_intersection(det_p,
+ gt_p) / det_p.area()
+                    except Exception:
+                        print(det_x, det_y, gt_p)
+                        continue
+ if det_gt_iou > threshold:
+ detections[det_id] = []
+
+ detections[:] = [item for item in detections if item != []]
+ return detections
+
+ def sigma_calculation(det_p, gt_p):
+ """
+ sigma = inter_area / gt_area
+ """
+ if gt_p.area() == 0.0:
+ return 0
+ return get_intersection(det_p, gt_p) / gt_p.area()
+
+ def tau_calculation(det_p, gt_p):
+ """
+ tau = inter_area / det_area
+ """
+ if det_p.area() == 0.0:
+ return 0
+ return get_intersection(det_p, gt_p) / det_p.area()
+
+ detections = []
+
+ for item in pred_bboxes:
+ detections.append(item[:, ::-1].reshape(-1))
+
+ groundtruths = gt_reading_mod(gt_label, text)
+
+ detections = detection_filtering(
+ detections, groundtruths) # filters detections overlapping with DC area
+
+ for idx in range(len(groundtruths) - 1, -1, -1):
+        # NOTE: the original implementation uses the 'orin' field to mark '#'
+        # (don't-care) regions, while here we rely on the transcription
+        # annotation, which may cause a slight F-score drop of about 0.12.
+ if groundtruths[idx]["transcription"] == "###":
+ groundtruths.pop(idx)
+
+ local_sigma_table = np.zeros((len(groundtruths), len(detections)))
+ local_tau_table = np.zeros((len(groundtruths), len(detections)))
+
+ for gt_id, gt in enumerate(groundtruths):
+ if len(detections) > 0:
+ for det_id, detection in enumerate(detections):
+ point_num = gt["points"].shape[1] // 2
+
+ gt_p = np.array(gt["points"]).reshape(point_num,
+ 2).astype("int32")
+ gt_p = plg.Polygon(gt_p)
+
+ det_y = detection[0::2]
+ det_x = detection[1::2]
+
+ det_p = np.concatenate((np.array(det_x), np.array(det_y)))
+
+ det_p = det_p.reshape(2, -1).transpose()
+ det_p = plg.Polygon(det_p)
+
+ local_sigma_table[gt_id, det_id] = sigma_calculation(det_p,
+ gt_p)
+ local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p)
+
+ data = {}
+ data["sigma"] = local_sigma_table
+ data["global_tau"] = local_tau_table
+ data["global_pred_str"] = ""
+ data["global_gt_str"] = ""
+ return data
+
+
+def combine_results(all_data, rec_flag=True):
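+    # Matching thresholds and weights (values follow the usual DetEval-style
+    # end-to-end metric):
+    #   tr    - recall (sigma) threshold
+    #   tp    - precision (tau) threshold
+    #   fsc_k - score assigned to one-to-many / many-to-one matches
+    #   k     - minimum number of overlapping candidates for a split/merge match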
+ tr = 0.7
+ tp = 0.6
+ fsc_k = 0.8
+ k = 2
+ global_sigma = []
+ global_tau = []
+ global_pred_str = []
+ global_gt_str = []
+
+ for data in all_data:
+ global_sigma.append(data["sigma"])
+ global_tau.append(data["global_tau"])
+ global_pred_str.append(data["global_pred_str"])
+ global_gt_str.append(data["global_gt_str"])
+
+ global_accumulative_recall = 0
+ global_accumulative_precision = 0
+ total_num_gt = 0
+ total_num_det = 0
+ hit_str_count = 0
+ hit_count = 0
+
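+    # The three matchers below implement the usual DetEval-style protocol:
+    # one_to_one handles unique gt/det pairs, one_to_many a single gt covered
+    # by several detections, and many_to_one a single detection covering
+    # several gts; the latter two cases are weighted by fsc_k.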
+ def one_to_one(
+ local_sigma_table,
+ local_tau_table,
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ idy,
+ rec_flag, ):
+ hit_str_num = 0
+ for gt_id in range(num_gt):
+ gt_matching_qualified_sigma_candidates = np.where(
+ local_sigma_table[gt_id, :] > tr)
+ gt_matching_num_qualified_sigma_candidates = (
+ gt_matching_qualified_sigma_candidates[0].shape[0])
+ gt_matching_qualified_tau_candidates = np.where(
+ local_tau_table[gt_id, :] > tp)
+ gt_matching_num_qualified_tau_candidates = (
+ gt_matching_qualified_tau_candidates[0].shape[0])
+
+ det_matching_qualified_sigma_candidates = np.where(
+ local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
+ > tr)
+ det_matching_num_qualified_sigma_candidates = (
+ det_matching_qualified_sigma_candidates[0].shape[0])
+ det_matching_qualified_tau_candidates = np.where(
+ local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
+ tp)
+ det_matching_num_qualified_tau_candidates = (
+ det_matching_qualified_tau_candidates[0].shape[0])
+
+ if ((gt_matching_num_qualified_sigma_candidates == 1) and
+ (gt_matching_num_qualified_tau_candidates == 1) and
+ (det_matching_num_qualified_sigma_candidates == 1) and
+ (det_matching_num_qualified_tau_candidates == 1)):
+ global_accumulative_recall = global_accumulative_recall + 1.0
+ global_accumulative_precision = global_accumulative_precision + 1.0
+ local_accumulative_recall = local_accumulative_recall + 1.0
+ local_accumulative_precision = local_accumulative_precision + 1.0
+
+ gt_flag[0, gt_id] = 1
+ matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
+ # recg start
+ if rec_flag:
+ gt_str_cur = global_gt_str[idy][gt_id]
+ pred_str_cur = global_pred_str[idy][matched_det_id[0]
+ .tolist()[0]]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ # recg end
+ det_flag[0, matched_det_id] = 1
+ return (
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ hit_str_num, )
+
+ def one_to_many(
+ local_sigma_table,
+ local_tau_table,
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ idy,
+ rec_flag, ):
+ hit_str_num = 0
+ for gt_id in range(num_gt):
+ # skip the following if the groundtruth was matched
+ if gt_flag[0, gt_id] > 0:
+ continue
+
+ non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
+ num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
+
+ if num_non_zero_in_sigma >= k:
+                #### search for all detections that overlap with this groundtruth
+ qualified_tau_candidates = np.where((local_tau_table[
+ gt_id, :] >= tp) & (det_flag[0, :] == 0))
+ num_qualified_tau_candidates = qualified_tau_candidates[
+ 0].shape[0]
+
+ if num_qualified_tau_candidates == 1:
+ if (local_tau_table[gt_id, qualified_tau_candidates] >= tp
+ ) and (
+ local_sigma_table[gt_id, qualified_tau_candidates]
+ >= tr):
+                        # becomes a one-to-one case
+ global_accumulative_recall = global_accumulative_recall + 1.0
+ global_accumulative_precision = (
+ global_accumulative_precision + 1.0)
+ local_accumulative_recall = local_accumulative_recall + 1.0
+ local_accumulative_precision = (
+ local_accumulative_precision + 1.0)
+
+ gt_flag[0, gt_id] = 1
+ det_flag[0, qualified_tau_candidates] = 1
+ # recg start
+ if rec_flag:
+ gt_str_cur = global_gt_str[idy][gt_id]
+ pred_str_cur = global_pred_str[idy][
+ qualified_tau_candidates[0].tolist()[0]]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ # recg end
+ elif np.sum(local_sigma_table[gt_id,
+ qualified_tau_candidates]) >= tr:
+ gt_flag[0, gt_id] = 1
+ det_flag[0, qualified_tau_candidates] = 1
+ # recg start
+ if rec_flag:
+ gt_str_cur = global_gt_str[idy][gt_id]
+ pred_str_cur = global_pred_str[idy][
+ qualified_tau_candidates[0].tolist()[0]]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ # recg end
+
+ global_accumulative_recall = global_accumulative_recall + fsc_k
+ global_accumulative_precision = (
+ global_accumulative_precision +
+ num_qualified_tau_candidates * fsc_k)
+
+ local_accumulative_recall = local_accumulative_recall + fsc_k
+ local_accumulative_precision = (
+ local_accumulative_precision +
+ num_qualified_tau_candidates * fsc_k)
+
+ return (
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ hit_str_num, )
+
+ def many_to_one(
+ local_sigma_table,
+ local_tau_table,
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ idy,
+ rec_flag, ):
+ hit_str_num = 0
+ for det_id in range(num_det):
+ # skip the following if the detection was matched
+ if det_flag[0, det_id] > 0:
+ continue
+
+ non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
+ num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
+
+ if num_non_zero_in_tau >= k:
+                #### search for all groundtruths that overlap with this detection
+ qualified_sigma_candidates = np.where((
+ local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
+ num_qualified_sigma_candidates = qualified_sigma_candidates[
+ 0].shape[0]
+
+ if num_qualified_sigma_candidates == 1:
+ if (
+ local_tau_table[qualified_sigma_candidates, det_id]
+ >= tp
+ ) and (local_sigma_table[qualified_sigma_candidates, det_id]
+ >= tr):
+                        # becomes a one-to-one case
+ global_accumulative_recall = global_accumulative_recall + 1.0
+ global_accumulative_precision = (
+ global_accumulative_precision + 1.0)
+ local_accumulative_recall = local_accumulative_recall + 1.0
+ local_accumulative_precision = (
+ local_accumulative_precision + 1.0)
+
+ gt_flag[0, qualified_sigma_candidates] = 1
+ det_flag[0, det_id] = 1
+ # recg start
+ if rec_flag:
+ pred_str_cur = global_pred_str[idy][det_id]
+ gt_len = len(qualified_sigma_candidates[0])
+ for idx in range(gt_len):
+ ele_gt_id = qualified_sigma_candidates[
+ 0].tolist()[idx]
+ if ele_gt_id not in global_gt_str[idy]:
+ continue
+ gt_str_cur = global_gt_str[idy][ele_gt_id]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ break
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower(
+ ):
+ hit_str_num += 1
+ break
+ # recg end
+ elif np.sum(local_tau_table[qualified_sigma_candidates,
+ det_id]) >= tp:
+ det_flag[0, det_id] = 1
+ gt_flag[0, qualified_sigma_candidates] = 1
+ # recg start
+ if rec_flag:
+ pred_str_cur = global_pred_str[idy][det_id]
+ gt_len = len(qualified_sigma_candidates[0])
+ for idx in range(gt_len):
+ ele_gt_id = qualified_sigma_candidates[0].tolist()[
+ idx]
+ if ele_gt_id not in global_gt_str[idy]:
+ continue
+ gt_str_cur = global_gt_str[idy][ele_gt_id]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ break
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ break
+ # recg end
+
+ global_accumulative_recall = (
+ global_accumulative_recall +
+ num_qualified_sigma_candidates * fsc_k)
+ global_accumulative_precision = (
+ global_accumulative_precision + fsc_k)
+
+ local_accumulative_recall = (
+ local_accumulative_recall +
+ num_qualified_sigma_candidates * fsc_k)
+ local_accumulative_precision = local_accumulative_precision + fsc_k
+ return (
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ hit_str_num, )
+
+ for idx in range(len(global_sigma)):
+ local_sigma_table = np.array(global_sigma[idx])
+ local_tau_table = global_tau[idx]
+
+ num_gt = local_sigma_table.shape[0]
+ num_det = local_sigma_table.shape[1]
+
+ total_num_gt = total_num_gt + num_gt
+ total_num_det = total_num_det + num_det
+
+ local_accumulative_recall = 0
+ local_accumulative_precision = 0
+ gt_flag = np.zeros((1, num_gt))
+ det_flag = np.zeros((1, num_det))
+
+ #######first check for one-to-one case##########
+ (
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ hit_str_num, ) = one_to_one(
+ local_sigma_table,
+ local_tau_table,
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ idx,
+ rec_flag, )
+
+ hit_str_count += hit_str_num
+ #######then check for one-to-many case##########
+ (
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ hit_str_num, ) = one_to_many(
+ local_sigma_table,
+ local_tau_table,
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ idx,
+ rec_flag, )
+ hit_str_count += hit_str_num
+ #######then check for many-to-one case##########
+ (
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ hit_str_num, ) = many_to_one(
+ local_sigma_table,
+ local_tau_table,
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ idx,
+ rec_flag, )
+ hit_str_count += hit_str_num
+
+ try:
+ recall = global_accumulative_recall / total_num_gt
+ except ZeroDivisionError:
+ recall = 0
+
+ try:
+ precision = global_accumulative_precision / total_num_det
+ except ZeroDivisionError:
+ precision = 0
+
+ try:
+ f_score = 2 * precision * recall / (precision + recall)
+ except ZeroDivisionError:
+ f_score = 0
+
+ try:
+ seqerr = 1 - float(hit_str_count) / global_accumulative_recall
+ except ZeroDivisionError:
+ seqerr = 1
+
+ try:
+ recall_e2e = float(hit_str_count) / total_num_gt
+ except ZeroDivisionError:
+ recall_e2e = 0
+
+ try:
+ precision_e2e = float(hit_str_count) / total_num_det
+ except ZeroDivisionError:
+ precision_e2e = 0
+
+ try:
+ f_score_e2e = 2 * precision_e2e * recall_e2e / (
+ precision_e2e + recall_e2e)
+ except ZeroDivisionError:
+ f_score_e2e = 0
+
+ final = {
+ "total_num_gt": total_num_gt,
+ "total_num_det": total_num_det,
+ "global_accumulative_recall": global_accumulative_recall,
+ "hit_str_count": hit_str_count,
+ "recall": recall,
+ "precision": precision,
+ "f_score": f_score,
+ "seqerr": seqerr,
+ "recall_e2e": recall_e2e,
+ "precision_e2e": precision_e2e,
+ "f_score_e2e": f_score_e2e,
+ }
+ return final
diff --git a/tools/utils/e2e_metric/polygon_fast.py b/tools/utils/e2e_metric/polygon_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..794b4f685c9a6cc79c83b8253c92bde7afc514b8
--- /dev/null
+++ b/tools/utils/e2e_metric/polygon_fast.py
@@ -0,0 +1,70 @@
+import numpy as np
+from shapely.geometry import Polygon
+"""
+:param det_x: [1, N] Xs of detection's vertices
+:param det_y: [1, N] Ys of detection's vertices
+:param gt_x: [1, N] Xs of groundtruth's vertices
+:param gt_y: [1, N] Ys of groundtruth's vertices
+
+##############
+All the calculation of 'AREA' in this script is handled by:
+1) First generating a binary mask with the polygon area filled up with 1's
+2) Summing up all the 1's
+"""
+
+
+def area(x, y):
+ polygon = Polygon(np.stack([x, y], axis=1))
+ return float(polygon.area)
+
+
+def approx_area_of_intersection(det_x, det_y, gt_x, gt_y):
+ """
+ This helper determine if both polygons are intersecting with each others with an approximation method.
+ Area of intersection represented by the minimum bounding rectangular [xmin, ymin, xmax, ymax]
+ """
+ det_ymax = np.max(det_y)
+ det_xmax = np.max(det_x)
+ det_ymin = np.min(det_y)
+ det_xmin = np.min(det_x)
+
+ gt_ymax = np.max(gt_y)
+ gt_xmax = np.max(gt_x)
+ gt_ymin = np.min(gt_y)
+ gt_xmin = np.min(gt_x)
+
+ all_min_ymax = np.minimum(det_ymax, gt_ymax)
+ all_max_ymin = np.maximum(det_ymin, gt_ymin)
+
+ intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin))
+
+ all_min_xmax = np.minimum(det_xmax, gt_xmax)
+ all_max_xmin = np.maximum(det_xmin, gt_xmin)
+ intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin))
+
+ return intersect_heights * intersect_widths
+
+
+def area_of_intersection(det_x, det_y, gt_x, gt_y):
+ p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
+ p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
+ return float(p1.intersection(p2).area)
+
+
+def area_of_union(det_x, det_y, gt_x, gt_y):
+ p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
+ p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
+ return float(p1.union(p2).area)
+
+
+def iou(det_x, det_y, gt_x, gt_y):
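+    # The +1.0 in the denominator (also used in iod below) avoids division by
+    # zero for degenerate polygons, at the cost of slightly underestimating
+    # the ratio.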
+ return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
+ area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)
+
+
+def iod(det_x, det_y, gt_x, gt_y):
+ """
+    Fraction of the intersection area over the detection area (intersection over detection).
+ """
+ return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
+ area(det_x, det_y) + 1.0)
diff --git a/tools/utils/e2e_utils/extract_batchsize.py b/tools/utils/e2e_utils/extract_batchsize.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e99a987692a133fe117cb803115d75f5839b94f
--- /dev/null
+++ b/tools/utils/e2e_utils/extract_batchsize.py
@@ -0,0 +1,86 @@
+import torch
+import numpy as np
+import copy
+
+
+def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
+ """ """
+ pos_lists_, pos_masks_, label_lists_ = [], [], []
+ img_bs = batch_size
+ ngpu = int(batch_size / img_bs)
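+    # img_bs == batch_size, so ngpu is always 1 here; the per-group split and
+    # merge below is kept from the multi-device origin of this routine.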
+ img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
+ pos_lists_split, pos_masks_split, label_lists_split = [], [], []
+ for i in range(ngpu):
+ pos_lists_split.append([])
+ pos_masks_split.append([])
+ label_lists_split.append([])
+
+ for i in range(img_ids.shape[0]):
+ img_id = img_ids[i]
+ gpu_id = int(img_id / img_bs)
+ img_id = img_id % img_bs
+ pos_list = pos_lists[i].copy()
+ pos_list[:, 0] = img_id
+ pos_lists_split[gpu_id].append(pos_list)
+ pos_masks_split[gpu_id].append(pos_masks[i].copy())
+ label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
+ # repeat or delete
+ for i in range(ngpu):
+ vp_len = len(pos_lists_split[i])
+ if vp_len <= tcl_bs:
+ for j in range(0, tcl_bs - vp_len):
+ pos_list = pos_lists_split[i][j].copy()
+ pos_lists_split[i].append(pos_list)
+ pos_mask = pos_masks_split[i][j].copy()
+ pos_masks_split[i].append(pos_mask)
+ label_list = copy.deepcopy(label_lists_split[i][j])
+ label_lists_split[i].append(label_list)
+ else:
+ for j in range(0, vp_len - tcl_bs):
+ c_len = len(pos_lists_split[i])
+ pop_id = np.random.permutation(c_len)[0]
+ pos_lists_split[i].pop(pop_id)
+ pos_masks_split[i].pop(pop_id)
+ label_lists_split[i].pop(pop_id)
+ # merge
+ for i in range(ngpu):
+ pos_lists_.extend(pos_lists_split[i])
+ pos_masks_.extend(pos_masks_split[i])
+ label_lists_.extend(label_lists_split[i])
+ return pos_lists_, pos_masks_, label_lists_
+
+
+def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
+ pad_num, tcl_bs):
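+    """
+    Flatten batched TCL labels/positions into per-instance lists, balance them
+    to `tcl_bs` RoIs with org_tcl_rois, and return torch tensors together with
+    the valid (non-padded) label length of each instance.
+    """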
+ label_list = label_list.numpy()
+ batch, _, _, _ = label_list.shape
+ pos_list = pos_list.numpy()
+ pos_mask = pos_mask.numpy()
+ pos_list_t = []
+ pos_mask_t = []
+ label_list_t = []
+ for i in range(batch):
+ for j in range(max_text_nums):
+ if pos_mask[i, j].any():
+ pos_list_t.append(pos_list[i][j])
+ pos_mask_t.append(pos_mask[i][j])
+ label_list_t.append(label_list[i][j])
+ pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t,
+ label_list_t, tcl_bs)
+ label = []
+ tt = [l.tolist() for l in label_list]
+ for i in range(tcl_bs):
+ k = 0
+ for j in range(max_text_length):
+ if tt[i][j][0] != pad_num:
+ k += 1
+ else:
+ break
+ label.append(k)
+ label = torch.tensor(label)
+ label = label.long()
+ pos_list = torch.tensor(pos_list)
+ pos_mask = torch.tensor(pos_mask)
+ label_list = torch.squeeze(torch.tensor(label_list), dim=2)
+ label_list = label_list.int()
+ return pos_list, pos_mask, label_list, label
diff --git a/tools/utils/e2e_utils/extract_textpoint_fast.py b/tools/utils/e2e_utils/extract_textpoint_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..459e66f65089d9c628ccec5d3e7026ae53179987
--- /dev/null
+++ b/tools/utils/e2e_utils/extract_textpoint_fast.py
@@ -0,0 +1,479 @@
+import cv2
+import math
+
+import numpy as np
+from itertools import groupby
+from skimage.morphology._skeletonize import thin
+
+
+def get_dict(character_dict_path):
+ character_str = ""
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
+ character_str += line
+ dict_character = list(character_str)
+ return dict_character
+
+
+def softmax(logits):
+ """
+ logits: N x d
+ """
+ max_value = np.max(logits, axis=1, keepdims=True)
+ exp = np.exp(logits - max_value)
+ exp_sum = np.sum(exp, axis=1, keepdims=True)
+ dist = exp / exp_sum
+ return dist
+
+
+def get_keep_pos_idxs(labels, remove_blank=None):
+ """
+ Remove duplicate and get pos idxs of keep items.
+ The value of keep_blank should be [None, 95].
+ """
+ duplicate_len_list = []
+ keep_pos_idx_list = []
+ keep_char_idx_list = []
+ for k, v_ in groupby(labels):
+ current_len = len(list(v_))
+ if k != remove_blank:
+ current_idx = int(sum(duplicate_len_list) + current_len // 2)
+ keep_pos_idx_list.append(current_idx)
+ keep_char_idx_list.append(k)
+ duplicate_len_list.append(current_len)
+ return keep_char_idx_list, keep_pos_idx_list
+
+
+def remove_blank(labels, blank=0):
+ new_labels = [x for x in labels if x != blank]
+ return new_labels
+
+
+def insert_blank(labels, blank=0):
+ new_labels = [blank]
+ for l in labels:
+ new_labels += [l, blank]
+ return new_labels
+
+
+def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
+ """
+ CTC greedy (best path) decoder.
+ """
+ raw_str = np.argmax(np.array(probs_seq), axis=1)
+ remove_blank_in_pos = None if keep_blank_in_idxs else blank
+ dedup_str, keep_idx_list = get_keep_pos_idxs(
+ raw_str, remove_blank=remove_blank_in_pos)
+ dst_str = remove_blank(dedup_str, blank=blank)
+ return dst_str, keep_idx_list
+
+
+def instance_ctc_greedy_decoder(gather_info,
+ logits_map,
+ pts_num=4,
+ point_gather_mode=None):
+ _, _, C = logits_map.shape
+ if point_gather_mode == "align":
+ insert_num = 0
+ gather_info = np.array(gather_info)
+ length = len(gather_info) - 1
+ for index in range(length):
+ stride_y = np.abs(gather_info[index + insert_num][0] - gather_info[
+ index + 1 + insert_num][0])
+ stride_x = np.abs(gather_info[index + insert_num][1] - gather_info[
+ index + 1 + insert_num][1])
+ max_points = int(max(stride_x, stride_y))
+ stride = (gather_info[index + insert_num] -
+ gather_info[index + 1 + insert_num]) / (max_points)
+ insert_num_temp = max_points - 1
+
+ for i in range(int(insert_num_temp)):
+ insert_value = gather_info[index + insert_num] - (i + 1
+ ) * stride
+ insert_index = index + i + 1 + insert_num
+ gather_info = np.insert(
+ gather_info, insert_index, insert_value, axis=0)
+ insert_num += insert_num_temp
+ gather_info = gather_info.tolist()
+ else:
+ pass
+ ys, xs = zip(*gather_info)
+ logits_seq = logits_map[list(ys), list(xs)]
+ probs_seq = logits_seq
+ labels = np.argmax(probs_seq, axis=1)
+ dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
+    delta = len(gather_info) // (pts_num - 1)
+    keep_idx_list = [0] + [delta * (i + 1) for i in range(pts_num - 2)] + [-1]
+ keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
+ return dst_str, keep_gather_list
+
+
+def ctc_decoder_for_image(gather_info_list,
+ logits_map,
+ Lexicon_Table,
+ pts_num=6,
+ point_gather_mode=None):
+ """
+ CTC decoder using multiple processes.
+ """
+ decoder_str = []
+ decoder_xys = []
+ for gather_info in gather_info_list:
+ if len(gather_info) < pts_num:
+ continue
+ dst_str, xys_list = instance_ctc_greedy_decoder(
+ gather_info,
+ logits_map,
+ pts_num=pts_num,
+ point_gather_mode=point_gather_mode, )
+ dst_str_readable = "".join([Lexicon_Table[idx] for idx in dst_str])
+ if len(dst_str_readable) < 2:
+ continue
+ decoder_str.append(dst_str_readable)
+ decoder_xys.append(xys_list)
+ return decoder_str, decoder_xys
+
+
+def sort_with_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+
+ def sort_part_with_direction(pos_list, point_direction):
+ pos_list = np.array(pos_list).reshape(-1, 2)
+ point_direction = np.array(point_direction).reshape(-1, 2)
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list, sorted_direction
+
+ pos_list = np.array(pos_list).reshape(-1, 2)
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+ point_direction)
+
+ point_num = len(sorted_point)
+ if point_num >= 16:
+ middle_num = point_num // 2
+ first_part_point = sorted_point[:middle_num]
+ first_point_direction = sorted_direction[:middle_num]
+        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
+            first_part_point, first_point_direction)
+
+        last_part_point = sorted_point[middle_num:]
+        last_point_direction = sorted_direction[middle_num:]
+        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+            last_part_point, last_point_direction)
+        sorted_point = sorted_first_part_point + sorted_last_part_point
+        sorted_direction = sorted_first_part_direction + sorted_last_part_direction
+
+ return sorted_point, np.array(sorted_direction)
+
+
+def add_id(pos_list, image_id=0):
+ """
+ Add id for gather feature, for inference.
+ """
+ new_list = []
+ for item in pos_list:
+ new_list.append((image_id, item[0], item[1]))
+ return new_list
+
+
+def sort_and_expand_with_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+ h, w, _ = f_direction.shape
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+ point_num = len(sorted_list)
+ sub_direction_len = max(point_num // 3, 2)
+ left_direction = point_direction[:sub_direction_len, :]
+    right_direction = point_direction[point_num - sub_direction_len:, :]
+
+    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+    left_average_len = np.linalg.norm(left_average_direction)
+    left_start = np.array(sorted_list[0])
+    left_step = left_average_direction / (left_average_len + 1e-6)
+
+    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
+ right_average_len = np.linalg.norm(right_average_direction)
+ right_step = right_average_direction / (right_average_len + 1e-6)
+ right_start = np.array(sorted_list[-1])
+
+ append_num = max(
+ int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ left_list = []
+ right_list = []
+ for i in range(append_num):
+ ly, lx = (np.round(left_start + left_step * (i + 1)).flatten()
+ .astype("int32").tolist())
+ if ly < h and lx < w and (ly, lx) not in left_list:
+ left_list.append((ly, lx))
+ ry, rx = (np.round(right_start + right_step * (i + 1)).flatten()
+ .astype("int32").tolist())
+ if ry < h and rx < w and (ry, rx) not in right_list:
+ right_list.append((ry, rx))
+
+ all_list = left_list[::-1] + sorted_list + right_list
+ return all_list
+
+
+def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ binary_tcl_map: h x w
+ """
+ h, w, _ = f_direction.shape
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+ point_num = len(sorted_list)
+ sub_direction_len = max(point_num // 3, 2)
+ left_direction = point_direction[:sub_direction_len, :]
+    right_direction = point_direction[point_num - sub_direction_len:, :]
+
+    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+    left_average_len = np.linalg.norm(left_average_direction)
+    left_start = np.array(sorted_list[0])
+    left_step = left_average_direction / (left_average_len + 1e-6)
+
+    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
+ right_average_len = np.linalg.norm(right_average_direction)
+ right_step = right_average_direction / (right_average_len + 1e-6)
+ right_start = np.array(sorted_list[-1])
+
+ append_num = max(
+ int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ max_append_num = 2 * append_num
+
+ left_list = []
+ right_list = []
+ for i in range(max_append_num):
+ ly, lx = (np.round(left_start + left_step * (i + 1)).flatten()
+ .astype("int32").tolist())
+ if ly < h and lx < w and (ly, lx) not in left_list:
+ if binary_tcl_map[ly, lx] > 0.5:
+ left_list.append((ly, lx))
+ else:
+ break
+
+ for i in range(max_append_num):
+ ry, rx = (np.round(right_start + right_step * (i + 1)).flatten()
+ .astype("int32").tolist())
+ if ry < h and rx < w and (ry, rx) not in right_list:
+ if binary_tcl_map[ry, rx] > 0.5:
+ right_list.append((ry, rx))
+ else:
+ break
+
+ all_list = left_list[::-1] + sorted_list + right_list
+ return all_list
+
+
+def point_pair2poly(point_pair_list):
+ """
+ Transfer vertical point_pairs into poly point in clockwise.
+ """
+ point_num = len(point_pair_list) * 2
+ point_list = [0] * point_num
+ for idx, point_pair in enumerate(point_pair_list):
+ point_list[idx] = point_pair[0]
+ point_list[point_num - 1 - idx] = point_pair[1]
+ return np.array(point_list).reshape(-1, 2)
+
+
+def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
+ ratio_pair = np.array(
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+
+def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
+ """
+ expand poly along width.
+ """
+ point_num = poly.shape[0]
+ left_quad = np.array(
+ [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_ratio = (-shrink_ratio_of_width *
+ np.linalg.norm(left_quad[0] - left_quad[3]) /
+ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6))
+ left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
+ right_quad = np.array(
+ [
+ poly[point_num // 2 - 2],
+ poly[point_num // 2 - 1],
+ poly[point_num // 2],
+ poly[point_num // 2 + 1],
+ ],
+ dtype=np.float32, )
+ right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[
+ 0] - right_quad[3]) / (np.linalg.norm(right_quad[0] - right_quad[1]) +
+ 1e-6)
+ right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
+ poly[0] = left_quad_expand[0]
+ poly[-1] = left_quad_expand[-1]
+ poly[point_num // 2 - 1] = right_quad_expand[1]
+ poly[point_num // 2] = right_quad_expand[2]
+ return poly
+
+
+def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
+ src_h, valid_set):
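+    """
+    Map each decoded center-line instance and its border offsets back to a
+    polygon in the original image coordinates; instances whose decoded string
+    is shorter than two characters are dropped.
+    """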
+ poly_list = []
+ keep_str_list = []
+ for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
+ if len(keep_str) < 2:
+ print("--> too short, {}".format(keep_str))
+ continue
+
+ offset_expand = 1.0
+ if valid_set == "totaltext":
+ offset_expand = 1.2
+
+ point_pair_list = []
+ for y, x in yx_center_line:
+ offset = p_border[:, y, x].reshape(2, 2) * offset_expand
+ ori_yx = np.array([y, x], dtype=np.float32)
+ point_pair = ((ori_yx + offset)[:, ::-1] * 4.0 /
+ np.array([ratio_w, ratio_h]).reshape(-1, 2))
+ point_pair_list.append(point_pair)
+
+ detected_poly = point_pair2poly(point_pair_list)
+ detected_poly = expand_poly_along_width(
+ detected_poly, shrink_ratio_of_width=0.2)
+ detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
+ detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
+
+ keep_str_list.append(keep_str)
+ if valid_set == "partvgg":
+ middle_point = len(detected_poly) // 2
+ detected_poly = detected_poly[
+ [0, middle_point - 1, middle_point, -1], :]
+ poly_list.append(detected_poly)
+ elif valid_set == "totaltext":
+ poly_list.append(detected_poly)
+ else:
+ print("--> Not supported format.")
+ exit(-1)
+ return poly_list, keep_str_list
+
+
+def generate_pivot_list_fast(
+ p_score,
+ p_char_maps,
+ f_direction,
+ Lexicon_Table,
+ score_thresh=0.5,
+ point_gather_mode=None, ):
+ """
+ return center point and end point of TCL instance; filter with the char maps;
+ """
+ p_score = p_score[0]
+ f_direction = f_direction.transpose(1, 2, 0)
+ p_tcl_map = (p_score > score_thresh) * 1.0
+ skeleton_map = thin(p_tcl_map.astype(np.uint8))
+ instance_count, instance_label_map = cv2.connectedComponents(
+ skeleton_map.astype(np.uint8), connectivity=8)
+
+ # get TCL Instance
+ all_pos_yxs = []
+ if instance_count > 0:
+ for instance_id in range(1, instance_count):
+ pos_list = []
+ ys, xs = np.where(instance_label_map == instance_id)
+ pos_list = list(zip(ys, xs))
+
+ if len(pos_list) < 3:
+ continue
+
+ pos_list_sorted = sort_and_expand_with_direction_v2(
+ pos_list, f_direction, p_tcl_map)
+ all_pos_yxs.append(pos_list_sorted)
+
+ p_char_maps = p_char_maps.transpose([1, 2, 0])
+ decoded_str, keep_yxs_list = ctc_decoder_for_image(
+ all_pos_yxs,
+ logits_map=p_char_maps,
+ Lexicon_Table=Lexicon_Table,
+ point_gather_mode=point_gather_mode, )
+ return keep_yxs_list, decoded_str
+
+
+def extract_main_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+ pos_list = np.array(pos_list)
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ average_direction = average_direction / (
+ np.linalg.norm(average_direction) + 1e-6)
+ return average_direction
+
+
+def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
+ """
+ pos_list_full = np.array(pos_list).reshape(-1, 3)
+ pos_list = pos_list_full[:, 1:]
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list
+
+
+def sort_by_direction_with_image_id(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+
+ def sort_part_with_direction(pos_list_full, point_direction):
+ pos_list_full = np.array(pos_list_full).reshape(-1, 3)
+ pos_list = pos_list_full[:, 1:]
+ point_direction = np.array(point_direction).reshape(-1, 2)
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list, sorted_direction
+
+ pos_list = np.array(pos_list).reshape(-1, 3)
+ point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+ point_direction)
+
+ point_num = len(sorted_point)
+ if point_num >= 16:
+ middle_num = point_num // 2
+ first_part_point = sorted_point[:middle_num]
+ first_point_direction = sorted_direction[:middle_num]
+        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
+            first_part_point, first_point_direction)
+
+        last_part_point = sorted_point[middle_num:]
+        last_point_direction = sorted_direction[middle_num:]
+        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+            last_part_point, last_point_direction)
+        sorted_point = sorted_first_part_point + sorted_last_part_point
+        sorted_direction = sorted_first_part_direction + sorted_last_part_direction
+
+ return sorted_point
diff --git a/tools/utils/e2e_utils/extract_textpoint_slow.py b/tools/utils/e2e_utils/extract_textpoint_slow.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0d449388950056a14ab8b8d2fafb4c20fb0c2d7
--- /dev/null
+++ b/tools/utils/e2e_utils/extract_textpoint_slow.py
@@ -0,0 +1,582 @@
+import cv2
+import math
+
+import numpy as np
+from itertools import groupby
+from skimage.morphology._skeletonize import thin
+
+
+def get_dict(character_dict_path):
+ character_str = ""
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
+ character_str += line
+ dict_character = list(character_str)
+ return dict_character
+
+
+def point_pair2poly(point_pair_list):
+ """
+ Transfer vertical point_pairs into poly point in clockwise.
+ """
+ pair_length_list = []
+ for point_pair in point_pair_list:
+ pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
+ pair_length_list.append(pair_length)
+ pair_length_list = np.array(pair_length_list)
+ pair_info = (
+ pair_length_list.max(),
+ pair_length_list.min(),
+ pair_length_list.mean(), )
+
+ point_num = len(point_pair_list) * 2
+ point_list = [0] * point_num
+ for idx, point_pair in enumerate(point_pair_list):
+ point_list[idx] = point_pair[0]
+ point_list[point_num - 1 - idx] = point_pair[1]
+ return np.array(point_list).reshape(-1, 2), pair_info
+
+
+def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
+ """
+ Generate shrink_quad_along_width.
+ """
+ ratio_pair = np.array(
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+
+def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
+ """
+ expand poly along width.
+ """
+ point_num = poly.shape[0]
+ left_quad = np.array(
+ [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_ratio = (-shrink_ratio_of_width *
+ np.linalg.norm(left_quad[0] - left_quad[3]) /
+ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6))
+ left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
+ right_quad = np.array(
+ [
+ poly[point_num // 2 - 2],
+ poly[point_num // 2 - 1],
+ poly[point_num // 2],
+ poly[point_num // 2 + 1],
+ ],
+ dtype=np.float32, )
+ right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[
+ 0] - right_quad[3]) / (np.linalg.norm(right_quad[0] - right_quad[1]) +
+ 1e-6)
+ right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
+ poly[0] = left_quad_expand[0]
+ poly[-1] = left_quad_expand[-1]
+ poly[point_num // 2 - 1] = right_quad_expand[1]
+ poly[point_num // 2] = right_quad_expand[2]
+ return poly
+
+
+def softmax(logits):
+ """
+ logits: N x d
+ """
+ max_value = np.max(logits, axis=1, keepdims=True)
+ exp = np.exp(logits - max_value)
+ exp_sum = np.sum(exp, axis=1, keepdims=True)
+ dist = exp / exp_sum
+ return dist
+
+
+def get_keep_pos_idxs(labels, remove_blank=None):
+ """
+ Remove duplicate and get pos idxs of keep items.
+ The value of keep_blank should be [None, 95].
+ """
+ duplicate_len_list = []
+ keep_pos_idx_list = []
+ keep_char_idx_list = []
+ for k, v_ in groupby(labels):
+ current_len = len(list(v_))
+ if k != remove_blank:
+ current_idx = int(sum(duplicate_len_list) + current_len // 2)
+ keep_pos_idx_list.append(current_idx)
+ keep_char_idx_list.append(k)
+ duplicate_len_list.append(current_len)
+ return keep_char_idx_list, keep_pos_idx_list
+
+
+def remove_blank(labels, blank=0):
+ new_labels = [x for x in labels if x != blank]
+ return new_labels
+
+
+def insert_blank(labels, blank=0):
+ new_labels = [blank]
+ for l in labels:
+ new_labels += [l, blank]
+ return new_labels
+
+
+def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
+ """
+ CTC greedy (best path) decoder.
+ """
+ raw_str = np.argmax(np.array(probs_seq), axis=1)
+ remove_blank_in_pos = None if keep_blank_in_idxs else blank
+ dedup_str, keep_idx_list = get_keep_pos_idxs(
+ raw_str, remove_blank=remove_blank_in_pos)
+ dst_str = remove_blank(dedup_str, blank=blank)
+ return dst_str, keep_idx_list
+
+
+def instance_ctc_greedy_decoder(gather_info,
+ logits_map,
+ keep_blank_in_idxs=True):
+ """
+ gather_info: [[x, y], [x, y] ...]
+ logits_map: H x W X (n_chars + 1)
+ """
+ _, _, C = logits_map.shape
+ ys, xs = zip(*gather_info)
+ logits_seq = logits_map[list(ys), list(xs)] # n x 96
+ probs_seq = softmax(logits_seq)
+ dst_str, keep_idx_list = ctc_greedy_decoder(
+ probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
+ keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
+ return dst_str, keep_gather_list
+
+
+def ctc_decoder_for_image(gather_info_list, logits_map,
+ keep_blank_in_idxs=True):
+ """
+ CTC decoder using multiple processes.
+ """
+ decoder_results = []
+ for gather_info in gather_info_list:
+ res = instance_ctc_greedy_decoder(
+ gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
+ decoder_results.append(res)
+ return decoder_results
+
+
+def sort_with_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+
+ def sort_part_with_direction(pos_list, point_direction):
+ pos_list = np.array(pos_list).reshape(-1, 2)
+ point_direction = np.array(point_direction).reshape(-1, 2)
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list, sorted_direction
+
+ pos_list = np.array(pos_list).reshape(-1, 2)
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+ point_direction)
+
+ point_num = len(sorted_point)
+ if point_num >= 16:
+ middle_num = point_num // 2
+ first_part_point = sorted_point[:middle_num]
+ first_point_direction = sorted_direction[:middle_num]
+        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
+            first_part_point, first_point_direction)
+
+        last_part_point = sorted_point[middle_num:]
+        last_point_direction = sorted_direction[middle_num:]
+        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+            last_part_point, last_point_direction)
+        sorted_point = sorted_first_part_point + sorted_last_part_point
+        sorted_direction = sorted_first_part_direction + sorted_last_part_direction
+
+ return sorted_point, np.array(sorted_direction)
+
+
+def add_id(pos_list, image_id=0):
+ """
+ Add id for gather feature, for inference.
+ """
+ new_list = []
+ for item in pos_list:
+ new_list.append((image_id, item[0], item[1]))
+ return new_list
+
+
+def sort_and_expand_with_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+ h, w, _ = f_direction.shape
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+ # expand along
+ point_num = len(sorted_list)
+ sub_direction_len = max(point_num // 3, 2)
+ left_direction = point_direction[:sub_direction_len, :]
+    right_direction = point_direction[point_num - sub_direction_len:, :]
+
+    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+    left_average_len = np.linalg.norm(left_average_direction)
+    left_start = np.array(sorted_list[0])
+    left_step = left_average_direction / (left_average_len + 1e-6)
+
+    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
+ right_average_len = np.linalg.norm(right_average_direction)
+ right_step = right_average_direction / (right_average_len + 1e-6)
+ right_start = np.array(sorted_list[-1])
+
+ append_num = max(
+ int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ left_list = []
+ right_list = []
+ for i in range(append_num):
+ ly, lx = (np.round(left_start + left_step * (i + 1)).flatten()
+ .astype("int32").tolist())
+ if ly < h and lx < w and (ly, lx) not in left_list:
+ left_list.append((ly, lx))
+ ry, rx = (np.round(right_start + right_step * (i + 1)).flatten()
+ .astype("int32").tolist())
+ if ry < h and rx < w and (ry, rx) not in right_list:
+ right_list.append((ry, rx))
+
+ all_list = left_list[::-1] + sorted_list + right_list
+ return all_list
+
+
+def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ binary_tcl_map: h x w
+ """
+ h, w, _ = f_direction.shape
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+ # expand along
+ point_num = len(sorted_list)
+ sub_direction_len = max(point_num // 3, 2)
+ left_direction = point_direction[:sub_direction_len, :]
+    right_direction = point_direction[point_num - sub_direction_len:, :]
+
+    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+    left_average_len = np.linalg.norm(left_average_direction)
+    left_start = np.array(sorted_list[0])
+    left_step = left_average_direction / (left_average_len + 1e-6)
+
+    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
+ right_average_len = np.linalg.norm(right_average_direction)
+ right_step = right_average_direction / (right_average_len + 1e-6)
+ right_start = np.array(sorted_list[-1])
+
+ append_num = max(
+ int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ max_append_num = 2 * append_num
+
+ left_list = []
+ right_list = []
+ for i in range(max_append_num):
+ ly, lx = (np.round(left_start + left_step * (i + 1)).flatten()
+ .astype("int32").tolist())
+ if ly < h and lx < w and (ly, lx) not in left_list:
+ if binary_tcl_map[ly, lx] > 0.5:
+ left_list.append((ly, lx))
+ else:
+ break
+
+ for i in range(max_append_num):
+ ry, rx = (np.round(right_start + right_step * (i + 1)).flatten()
+ .astype("int32").tolist())
+ if ry < h and rx < w and (ry, rx) not in right_list:
+ if binary_tcl_map[ry, rx] > 0.5:
+ right_list.append((ry, rx))
+ else:
+ break
+
+ all_list = left_list[::-1] + sorted_list + right_list
+ return all_list
+
+
+def generate_pivot_list_curved(
+ p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_expand=True,
+ is_backbone=False,
+ image_id=0, ):
+ """
+ return center point and end point of TCL instance; filter with the char maps;
+ """
+ p_score = p_score[0]
+ f_direction = f_direction.transpose(1, 2, 0)
+ p_tcl_map = (p_score > score_thresh) * 1.0
+ skeleton_map = thin(p_tcl_map)
+ instance_count, instance_label_map = cv2.connectedComponents(
+ skeleton_map.astype(np.uint8), connectivity=8)
+
+ # get TCL Instance
+ all_pos_yxs = []
+ center_pos_yxs = []
+ end_points_yxs = []
+ instance_center_pos_yxs = []
+ pred_strs = []
+ if instance_count > 0:
+ for instance_id in range(1, instance_count):
+ pos_list = []
+ ys, xs = np.where(instance_label_map == instance_id)
+ pos_list = list(zip(ys, xs))
+
+ ### FIX-ME, eliminate outlier
+ if len(pos_list) < 3:
+ continue
+
+ if is_expand:
+ pos_list_sorted = sort_and_expand_with_direction_v2(
+ pos_list, f_direction, p_tcl_map)
+ else:
+ pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
+ all_pos_yxs.append(pos_list_sorted)
+
+        # use the decoder to filter out background points.
+ p_char_maps = p_char_maps.transpose([1, 2, 0])
+ decode_res = ctc_decoder_for_image(
+ all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+ for decoded_str, keep_yxs_list in decode_res:
+ if is_backbone:
+ keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
+ instance_center_pos_yxs.append(keep_yxs_list_with_id)
+ pred_strs.append(decoded_str)
+ else:
+ end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
+ center_pos_yxs.extend(keep_yxs_list)
+
+ if is_backbone:
+ return pred_strs, instance_center_pos_yxs
+ else:
+ return center_pos_yxs, end_points_yxs
+
+
+def generate_pivot_list_horizontal(p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_backbone=False,
+ image_id=0):
+ """
+ return center point and end point of TCL instance; filter with the char maps;
+ """
+ p_score = p_score[0]
+ f_direction = f_direction.transpose(1, 2, 0)
+ p_tcl_map_bi = (p_score > score_thresh) * 1.0
+ instance_count, instance_label_map = cv2.connectedComponents(
+ p_tcl_map_bi.astype(np.uint8), connectivity=8)
+
+ # get TCL Instance
+ all_pos_yxs = []
+ center_pos_yxs = []
+ end_points_yxs = []
+ instance_center_pos_yxs = []
+
+ if instance_count > 0:
+ for instance_id in range(1, instance_count):
+ pos_list = []
+ ys, xs = np.where(instance_label_map == instance_id)
+ pos_list = list(zip(ys, xs))
+
+ ### FIX-ME, eliminate outlier
+ if len(pos_list) < 5:
+ continue
+
+ # add rule here
+ main_direction = extract_main_direction(pos_list,
+ f_direction) # y x
+            reference_direction = np.array([0, 1]).reshape([-1, 2])  # y x
+            is_h_angle = abs(np.sum(
+                main_direction * reference_direction)) < math.cos(
+                    math.pi / 180 * 70)
+
+ point_yxs = np.array(pos_list)
+ max_y, max_x = np.max(point_yxs, axis=0)
+ min_y, min_x = np.min(point_yxs, axis=0)
+ is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
+
+ pos_list_final = []
+ if is_h_len:
+ xs = np.unique(xs)
+ for x in xs:
+ ys = instance_label_map[:, x].copy().reshape((-1, ))
+ y = int(np.where(ys == instance_id)[0].mean())
+ pos_list_final.append((y, x))
+ else:
+ ys = np.unique(ys)
+ for y in ys:
+ xs = instance_label_map[y, :].copy().reshape((-1, ))
+ x = int(np.where(xs == instance_id)[0].mean())
+ pos_list_final.append((y, x))
+
+ pos_list_sorted, _ = sort_with_direction(pos_list_final,
+ f_direction)
+ all_pos_yxs.append(pos_list_sorted)
+
+        # use the decoder to filter out background points.
+ p_char_maps = p_char_maps.transpose([1, 2, 0])
+ decode_res = ctc_decoder_for_image(
+ all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+ for decoded_str, keep_yxs_list in decode_res:
+ if is_backbone:
+ keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
+ instance_center_pos_yxs.append(keep_yxs_list_with_id)
+ else:
+ end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
+ center_pos_yxs.extend(keep_yxs_list)
+
+ if is_backbone:
+ return instance_center_pos_yxs
+ else:
+ return center_pos_yxs, end_points_yxs
+
+
+def generate_pivot_list_slow(
+ p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_backbone=False,
+ is_curved=True,
+ image_id=0, ):
+ """
+ Warp all the function together.
+ """
+ if is_curved:
+ return generate_pivot_list_curved(
+ p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=score_thresh,
+ is_expand=True,
+ is_backbone=is_backbone,
+ image_id=image_id, )
+ else:
+ return generate_pivot_list_horizontal(
+ p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=score_thresh,
+ is_backbone=is_backbone,
+ image_id=image_id, )
+
+
+# for refine module
+def extract_main_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+ pos_list = np.array(pos_list)
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ average_direction = average_direction / (
+ np.linalg.norm(average_direction) + 1e-6)
+ return average_direction
+
+
+def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
+ """
+ pos_list_full = np.array(pos_list).reshape(-1, 3)
+ pos_list = pos_list_full[:, 1:]
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list
+
+
+def sort_by_direction_with_image_id(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+
+ def sort_part_with_direction(pos_list_full, point_direction):
+ pos_list_full = np.array(pos_list_full).reshape(-1, 3)
+ pos_list = pos_list_full[:, 1:]
+ point_direction = np.array(point_direction).reshape(-1, 2)
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list, sorted_direction
+
+ pos_list = np.array(pos_list).reshape(-1, 3)
+ point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+ point_direction)
+
+ point_num = len(sorted_point)
+ if point_num >= 16:
+ middle_num = point_num // 2
+ first_part_point = sorted_point[:middle_num]
+ first_point_direction = sorted_direction[:middle_num]
+        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
+            first_part_point, first_point_direction)
+
+        last_part_point = sorted_point[middle_num:]
+        last_point_direction = sorted_direction[middle_num:]
+        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+            last_part_point, last_point_direction)
+        sorted_point = sorted_first_part_point + sorted_last_part_point
+        sorted_direction = sorted_first_part_direction + sorted_last_part_direction
+
+ return sorted_point
+
+
+def generate_pivot_list_tt_inference(
+ p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_backbone=False,
+ is_curved=True,
+ image_id=0, ):
+ """
+ return center point and end point of TCL instance; filter with the char maps;
+ """
+ p_score = p_score[0]
+ f_direction = f_direction.transpose(1, 2, 0)
+ p_tcl_map = (p_score > score_thresh) * 1.0
+ skeleton_map = thin(p_tcl_map)
+ instance_count, instance_label_map = cv2.connectedComponents(
+ skeleton_map.astype(np.uint8), connectivity=8)
+
+ # get TCL Instance
+ all_pos_yxs = []
+ if instance_count > 0:
+ for instance_id in range(1, instance_count):
+ pos_list = []
+ ys, xs = np.where(instance_label_map == instance_id)
+ pos_list = list(zip(ys, xs))
+ ### FIX-ME, eliminate outlier
+ if len(pos_list) < 3:
+ continue
+ pos_list_sorted = sort_and_expand_with_direction_v2(
+ pos_list, f_direction, p_tcl_map)
+ pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
+ all_pos_yxs.append(pos_list_sorted_with_id)
+ return all_pos_yxs
diff --git a/tools/utils/e2e_utils/pgnet_pp_utils.py b/tools/utils/e2e_utils/pgnet_pp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac8eb71d7dfc43588c7fa70b6717e8ac4b396390
--- /dev/null
+++ b/tools/utils/e2e_utils/pgnet_pp_utils.py
@@ -0,0 +1,159 @@
+import torch
+import os
+import sys
+
+__dir__ = os.path.dirname(__file__)
+sys.path.append(__dir__)
+sys.path.append(os.path.join(__dir__, ".."))
+from extract_textpoint_slow import *
+from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
+
+
+class PGNet_PostProcess(object):
+    # Provides two post-processing variants: a fast path and a slower reference path.
+ def __init__(
+ self,
+ character_dict_path,
+ valid_set,
+ score_thresh,
+ outs_dict,
+ shape_list,
+ point_gather_mode=None, ):
+ self.Lexicon_Table = get_dict(character_dict_path)
+ self.valid_set = valid_set
+ self.score_thresh = score_thresh
+ self.outs_dict = outs_dict
+ self.shape_list = shape_list
+ self.point_gather_mode = point_gather_mode
+
+ def pg_postprocess_fast(self):
+ p_score = self.outs_dict["f_score"]
+ p_border = self.outs_dict["f_border"]
+ p_char = self.outs_dict["f_char"]
+ p_direction = self.outs_dict["f_direction"]
+ if isinstance(p_score, torch.Tensor):
+ p_score = p_score[0].numpy()
+ p_border = p_border[0].numpy()
+ p_direction = p_direction[0].numpy()
+ p_char = p_char[0].numpy()
+ else:
+ p_score = p_score[0]
+ p_border = p_border[0]
+ p_direction = p_direction[0]
+ p_char = p_char[0]
+
+ src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
+ instance_yxs_list, seq_strs = generate_pivot_list_fast(
+ p_score,
+ p_char,
+ p_direction,
+ self.Lexicon_Table,
+ score_thresh=self.score_thresh,
+ point_gather_mode=self.point_gather_mode, )
+ poly_list, keep_str_list = restore_poly(
+ instance_yxs_list,
+ seq_strs,
+ p_border,
+ ratio_w,
+ ratio_h,
+ src_w,
+ src_h,
+ self.valid_set, )
+ data = {
+ "points": poly_list,
+ "texts": keep_str_list,
+ }
+ return data
+
+ def pg_postprocess_slow(self):
+ p_score = self.outs_dict["f_score"]
+ p_border = self.outs_dict["f_border"]
+ p_char = self.outs_dict["f_char"]
+ p_direction = self.outs_dict["f_direction"]
+ if isinstance(p_score, torch.Tensor):
+ p_score = p_score[0].numpy()
+ p_border = p_border[0].numpy()
+ p_direction = p_direction[0].numpy()
+ p_char = p_char[0].numpy()
+ else:
+ p_score = p_score[0]
+ p_border = p_border[0]
+ p_direction = p_direction[0]
+ p_char = p_char[0]
+ src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
+ is_curved = self.valid_set == "totaltext"
+ char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
+ p_score,
+ p_char,
+ p_direction,
+ score_thresh=self.score_thresh,
+ is_backbone=True,
+ is_curved=is_curved, )
+ seq_strs = []
+ for char_idx_set in char_seq_idx_set:
+ pr_str = "".join([self.Lexicon_Table[pos] for pos in char_idx_set])
+ seq_strs.append(pr_str)
+ poly_list = []
+ keep_str_list = []
+ all_point_list = []
+ all_point_pair_list = []
+ for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
+ if len(yx_center_line) == 1:
+ yx_center_line.append(yx_center_line[-1])
+
+ offset_expand = 1.0
+ if self.valid_set == "totaltext":
+ offset_expand = 1.2
+
+ point_pair_list = []
+ for batch_id, y, x in yx_center_line:
+ offset = p_border[:, y, x].reshape(2, 2)
+ if offset_expand != 1.0:
+ offset_length = np.linalg.norm(
+ offset, axis=1, keepdims=True)
+ expand_length = np.clip(
+ offset_length * (offset_expand - 1),
+ a_min=0.5,
+ a_max=3.0)
+                    offset_delta = offset / offset_length * expand_length
+                    offset = offset + offset_delta
+ ori_yx = np.array([y, x], dtype=np.float32)
+ point_pair = ((ori_yx + offset)[:, ::-1] * 4.0 /
+ np.array([ratio_w, ratio_h]).reshape(-1, 2))
+ point_pair_list.append(point_pair)
+
+ all_point_list.append([
+ int(round(x * 4.0 / ratio_w)),
+ int(round(y * 4.0 / ratio_h))
+ ])
+ all_point_pair_list.append(point_pair.round().astype(np.int32)
+ .tolist())
+
+ detected_poly, pair_length_info = point_pair2poly(point_pair_list)
+ detected_poly = expand_poly_along_width(
+ detected_poly, shrink_ratio_of_width=0.2)
+ detected_poly[:, 0] = np.clip(
+ detected_poly[:, 0], a_min=0, a_max=src_w)
+ detected_poly[:, 1] = np.clip(
+ detected_poly[:, 1], a_min=0, a_max=src_h)
+
+ if len(keep_str) < 2:
+ continue
+
+ keep_str_list.append(keep_str)
+ detected_poly = np.round(detected_poly).astype("int32")
+ if self.valid_set == "partvgg":
+ middle_point = len(detected_poly) // 2
+ detected_poly = detected_poly[
+ [0, middle_point - 1, middle_point, -1], :]
+ poly_list.append(detected_poly)
+ elif self.valid_set == "totaltext":
+ poly_list.append(detected_poly)
+            else:
+                print("--> Unsupported valid_set: {}".format(self.valid_set))
+                exit(-1)
+ data = {
+ "points": poly_list,
+ "texts": keep_str_list,
+ }
+ return data
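+
+
+# Usage sketch (illustrative only; the dictionary path and tensor contents are
+# placeholders, not part of this module):
+#
+#   outs_dict = {"f_score": f_score, "f_border": f_border,
+#                "f_char": f_char, "f_direction": f_direction}
+#   shape_list = [[src_h, src_w, ratio_h, ratio_w]]
+#   post = PGNet_PostProcess("path/to/e2e_char_dict.txt", "totaltext", 0.5,
+#                            outs_dict, shape_list)
+#   result = post.pg_postprocess_fast()  # {"points": [...], "texts": [...]}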
diff --git a/tools/utils/e2e_utils/visual.py b/tools/utils/e2e_utils/visual.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2509bdb60751e010d94c3cbfe28313496cfa863
--- /dev/null
+++ b/tools/utils/e2e_utils/visual.py
@@ -0,0 +1,152 @@
+import numpy as np
+import cv2
+
+
+def resize_image(im, max_side_len=512):
+ """
+ resize image to a size multiple of max_stride which is required by the network
+ :param im: the resized image
+ :param max_side_len: limit of max image size to avoid out of memory in gpu
+ :return: the resized image and the resize ratio
+ """
+ h, w, _ = im.shape
+
+ resize_w = w
+ resize_h = h
+
+ if resize_h > resize_w:
+ ratio = float(max_side_len) / resize_h
+ else:
+ ratio = float(max_side_len) / resize_w
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+
+ return im, (ratio_h, ratio_w)
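+
+# Worked example (illustrative): a 500x700 image with max_side_len=512 is
+# scaled by 512/700 and rounded up to multiples of 128, coming out as 384x512
+# with ratio_h = 384/500 = 0.768 and ratio_w = 512/700 ≈ 0.731.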
+
+
+def resize_image_min(im, max_side_len=512):
+ """ """
+ h, w, _ = im.shape
+
+ resize_w = w
+ resize_h = h
+
+ if resize_h < resize_w:
+ ratio = float(max_side_len) / resize_h
+ else:
+ ratio = float(max_side_len) / resize_w
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+ return im, (ratio_h, ratio_w)
+
+
+def resize_image_for_totaltext(im, max_side_len=512):
+ """ """
+ h, w, _ = im.shape
+
+ resize_w = w
+ resize_h = h
+ ratio = 1.25
+ if h * ratio > max_side_len:
+ ratio = float(max_side_len) / resize_h
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+ return im, (ratio_h, ratio_w)
+
+
+def point_pair2poly(point_pair_list):
+ """
+ Transfer vertical point_pairs into poly point in clockwise.
+ """
+ pair_length_list = []
+ for point_pair in point_pair_list:
+ pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
+ pair_length_list.append(pair_length)
+ pair_length_list = np.array(pair_length_list)
+ pair_info = (
+ pair_length_list.max(),
+ pair_length_list.min(),
+ pair_length_list.mean(), )
+
+ point_num = len(point_pair_list) * 2
+ point_list = [0] * point_num
+ for idx, point_pair in enumerate(point_pair_list):
+ point_list[idx] = point_pair[0]
+ point_list[point_num - 1 - idx] = point_pair[1]
+ return np.array(point_list).reshape(-1, 2), pair_info
+
+
+def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
+ """
+ Generate shrink_quad_along_width.
+ """
+ ratio_pair = np.array(
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+
+def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
+ """
+ expand poly along width.
+ """
+ point_num = poly.shape[0]
+ left_quad = np.array(
+ [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_ratio = (-shrink_ratio_of_width *
+ np.linalg.norm(left_quad[0] - left_quad[3]) /
+ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6))
+ left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
+ right_quad = np.array(
+ [
+ poly[point_num // 2 - 2],
+ poly[point_num // 2 - 1],
+ poly[point_num // 2],
+ poly[point_num // 2 + 1],
+ ],
+ dtype=np.float32, )
+ right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[
+ 0] - right_quad[3]) / (np.linalg.norm(right_quad[0] - right_quad[1]) +
+ 1e-6)
+ right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
+ poly[0] = left_quad_expand[0]
+ poly[-1] = left_quad_expand[-1]
+ poly[point_num // 2 - 1] = right_quad_expand[1]
+ poly[point_num // 2] = right_quad_expand[2]
+ return poly
+
+
+def norm2(x, axis=None):
+    if axis is not None:
+ return np.sqrt(np.sum(x**2, axis=axis))
+ return np.sqrt(np.sum(x**2))
+
+
+def cos(p1, p2):
+ return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
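+
+
+if __name__ == "__main__":
+    # Minimal self-check (illustrative): three (top, bottom) point pairs along
+    # a horizontal text line become a 6-point clockwise polygon, which is then
+    # widened slightly at both ends.
+    pairs = [
+        np.array([[0.0, 0.0], [0.0, 10.0]]),
+        np.array([[10.0, 0.0], [10.0, 10.0]]),
+        np.array([[20.0, 0.0], [20.0, 10.0]]),
+    ]
+    poly, pair_info = point_pair2poly(pairs)
+    poly = expand_poly_along_width(poly, shrink_ratio_of_width=0.2)
+    print("pair (max, min, mean):", pair_info)
+    print("expanded polygon:\n", poly)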
diff --git a/tools/utils/en_dict.txt b/tools/utils/en_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7677d31b9d3f08eef2823c2cf051beeab1f0470b
--- /dev/null
+++ b/tools/utils/en_dict.txt
@@ -0,0 +1,95 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+\
+]
+^
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+|
+}
+~
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+
diff --git a/tools/utils/gen_label.py b/tools/utils/gen_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..65baffa7d3f60be81ed5932016181d75867f8fb8
--- /dev/null
+++ b/tools/utils/gen_label.py
@@ -0,0 +1,68 @@
+import os
+import argparse
+import json
+
+
+def gen_rec_label(input_path, out_label):
+ with open(out_label, "w") as out_file:
+ with open(input_path, "r") as f:
+ for line in f.readlines():
+ tmp = line.strip("\n").replace(" ", "").split(",")
+ img_path, label = tmp[0], tmp[1]
+ label = label.replace('"', "")
+ out_file.write(img_path + "\t" + label + "\n")
+
+
+def gen_det_label(root_path, input_dir, out_label):
+ with open(out_label, "w") as out_file:
+ for label_file in os.listdir(input_dir):
+            # assumes ICDAR-style annotation names such as "gt_img_1.txt";
+            # [3:-4] strips the "gt_" prefix and the ".txt" suffix
+            img_path = os.path.join(root_path, label_file[3:-4] + ".jpg")
+ label = []
+ with open(
+ os.path.join(input_dir, label_file), "r",
+ encoding="utf-8-sig") as f:
+ for line in f.readlines():
+ tmp = line.strip("\n\r").replace("\xef\xbb\xbf",
+ "").split(",")
+ points = tmp[:8]
+ s = []
+ for i in range(0, len(points), 2):
+ b = points[i:i + 2]
+ b = [int(t) for t in b]
+ s.append(b)
+ result = {"transcription": tmp[8], "points": s}
+ label.append(result)
+
+ out_file.write(img_path + "\t" + json.dumps(
+ label, ensure_ascii=False) + "\n")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--mode",
+ type=str,
+ default="rec",
+ help="Generate rec_label or det_label, can be set rec or det", )
+ parser.add_argument(
+ "--root_path",
+ type=str,
+ default=".",
+ help="The root directory of images.Only takes effect when mode=det ", )
+ parser.add_argument(
+ "--input_path",
+ type=str,
+ default=".",
+ help="Input_label or input path to be converted", )
+ parser.add_argument(
+ "--output_label",
+ type=str,
+ default="out_label.txt",
+ help="Output file name")
+
+ args = parser.parse_args()
+ if args.mode == "rec":
+ print("Generate rec label")
+ gen_rec_label(args.input_path, args.output_label)
+ elif args.mode == "det":
+ gen_det_label(args.root_path, args.input_path, args.output_label)
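+
+
+# Example invocations (illustrative; all paths are placeholders):
+#
+#   python tools/utils/gen_label.py --mode rec \
+#       --input_path ./train_list.txt --output_label ./rec_gt_label.txt
+#   python tools/utils/gen_label.py --mode det --root_path ./icdar2015/images \
+#       --input_path ./icdar2015/gt --output_label ./det_gt_label.txt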
diff --git a/tools/utils/ic15_dict.txt b/tools/utils/ic15_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..474060366f8a2a00c108d5c743821c0a61867cd5
--- /dev/null
+++ b/tools/utils/ic15_dict.txt
@@ -0,0 +1,36 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
\ No newline at end of file
diff --git a/tools/utils/logging.py b/tools/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd7d641e0452457f3cb81b71351cde1e108c5798
--- /dev/null
+++ b/tools/utils/logging.py
@@ -0,0 +1,56 @@
+import os
+import sys
+import logging
+import functools
+
+logger_initialized = {}
+
+
+@functools.lru_cache()
+def get_logger(name="openrec", log_file=None, log_level=logging.DEBUG):
+ """Initialize and get a logger by name.
+ If the logger has not been initialized, this method will initialize the
+ logger by adding one or two handlers, otherwise the initialized logger will
+ be directly returned. During initialization, a StreamHandler will always be
+ added. If `log_file` is specified a FileHandler will also be added.
+ Args:
+ name (str): Logger name.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the logger.
+        log_level (int): The logger level. Only the rank-0 process uses this
+            level; other processes are set to ERROR and are therefore silent
+            most of the time.
+ Returns:
+ logging.Logger: The expected logger.
+ """
+ logger = logging.getLogger(name)
+ if name in logger_initialized:
+ return logger
+ for logger_name in logger_initialized:
+ if name.startswith(logger_name):
+ return logger
+
+ formatter = logging.Formatter(
+ "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
+ datefmt="%Y/%m/%d %H:%M:%S")
+
+ stream_handler = logging.StreamHandler(stream=sys.stdout)
+ stream_handler.setFormatter(formatter)
+ logger.addHandler(stream_handler)
+
+ rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else 0
+ if log_file is not None and rank == 0:
+ log_file_folder = os.path.split(log_file)[0]
+ os.makedirs(log_file_folder, exist_ok=True)
+ file_handler = logging.FileHandler(log_file, "a")
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+ if rank == 0:
+ logger.setLevel(log_level)
+ else:
+ logger.setLevel(logging.ERROR)
+ logger_initialized[name] = True
+ logger.propagate = False
+ return logger
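+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative): the first call builds the handlers,
+    # later calls with the same or a derived name reuse them.
+    logger = get_logger("openrec")
+    logger.info("logger initialized")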
diff --git a/tools/utils/poly_nms.py b/tools/utils/poly_nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..797f965e744ba5bef1b9b5a6d2071caff610664e
--- /dev/null
+++ b/tools/utils/poly_nms.py
@@ -0,0 +1,132 @@
+import numpy as np
+from shapely.geometry import Polygon
+
+
+def points2polygon(points):
+ """Convert k points to 1 polygon.
+
+ Args:
+ points (ndarray or list): A ndarray or a list of shape (2k)
+ that indicates k points.
+
+ Returns:
+ polygon (Polygon): A polygon object.
+ """
+ if isinstance(points, list):
+ points = np.array(points)
+
+ assert isinstance(points, np.ndarray)
+ assert (points.size % 2 == 0) and (points.size >= 8)
+
+ point_mat = points.reshape([-1, 2])
+ return Polygon(point_mat)
+
+
+def poly_intersection(poly_det, poly_gt, buffer=0.0001):
+ """Calculate the intersection area between two polygon.
+
+ Args:
+ poly_det (Polygon): A polygon predicted by detector.
+ poly_gt (Polygon): A gt polygon.
+
+ Returns:
+ intersection_area (float): The intersection area between two polygons.
+ """
+ assert isinstance(poly_det, Polygon)
+ assert isinstance(poly_gt, Polygon)
+
+ if buffer == 0:
+ poly_inter = poly_det & poly_gt
+ else:
+ poly_inter = poly_det.buffer(buffer) & poly_gt.buffer(buffer)
+ return poly_inter.area, poly_inter
+
+
+def poly_union(poly_det, poly_gt):
+ """Calculate the union area between two polygon.
+
+ Args:
+ poly_det (Polygon): A polygon predicted by detector.
+ poly_gt (Polygon): A gt polygon.
+
+ Returns:
+ union_area (float): The union area between two polygons.
+ """
+ assert isinstance(poly_det, Polygon)
+ assert isinstance(poly_gt, Polygon)
+
+ area_det = poly_det.area
+ area_gt = poly_gt.area
+ area_inters, _ = poly_intersection(poly_det, poly_gt)
+ return area_det + area_gt - area_inters
+
+
+def valid_boundary(x, with_score=True):
+ num = len(x)
+ if num < 8:
+ return False
+ if num % 2 == 0 and (not with_score):
+ return True
+ if num % 2 == 1 and with_score:
+ return True
+
+ return False
+
+
+def boundary_iou(src, target):
+ """Calculate the IOU between two boundaries.
+
+ Args:
+ src (list): Source boundary.
+ target (list): Target boundary.
+
+ Returns:
+ iou (float): The iou between two boundaries.
+ """
+ assert valid_boundary(src, False)
+ assert valid_boundary(target, False)
+ src_poly = points2polygon(src)
+ target_poly = points2polygon(target)
+
+ return poly_iou(src_poly, target_poly)
+
+
+def poly_iou(poly_det, poly_gt):
+ """Calculate the IOU between two polygons.
+
+ Args:
+ poly_det (Polygon): A polygon predicted by detector.
+ poly_gt (Polygon): A gt polygon.
+
+ Returns:
+ iou (float): The IOU between two polygons.
+ """
+ assert isinstance(poly_det, Polygon)
+ assert isinstance(poly_gt, Polygon)
+ area_inters, _ = poly_intersection(poly_det, poly_gt)
+ area_union = poly_union(poly_det, poly_gt)
+ if area_union == 0:
+ return 0.0
+ return area_inters / area_union
+
+
+def poly_nms(polygons, threshold):
+    """Greedy NMS over polygons.
+
+    Each element of `polygons` is a flat list of 2k coordinates followed by a
+    confidence score; any polygon whose boundary IoU with an already kept
+    polygon exceeds `threshold` is suppressed.
+    """
+    assert isinstance(polygons, list)
+
+ polygons = np.array(sorted(polygons, key=lambda x: x[-1]))
+
+ keep_poly = []
+ index = [i for i in range(polygons.shape[0])]
+
+ while len(index) > 0:
+ keep_poly.append(polygons[index[-1]].tolist())
+ A = polygons[index[-1]][:-1]
+ index = np.delete(index, -1)
+ iou_list = np.zeros((len(index), ))
+ for i in range(len(index)):
+ B = polygons[index[i]][:-1]
+ iou_list[i] = boundary_iou(A, B)
+ remove_index = np.where(iou_list > threshold)
+ index = np.delete(index, remove_index)
+
+ return keep_poly
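+
+
+if __name__ == "__main__":
+    # Tiny self-check (illustrative): two heavily overlapping 10x10 squares and
+    # one disjoint square; NMS at an IoU threshold of 0.5 keeps two of them.
+    candidates = [
+        [0, 0, 10, 0, 10, 10, 0, 10, 0.9],
+        [1, 1, 11, 1, 11, 11, 1, 11, 0.8],
+        [20, 20, 30, 20, 30, 30, 20, 30, 0.7],
+    ]
+    kept = poly_nms(candidates, threshold=0.5)
+    print("{} polygons kept".format(len(kept)))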
diff --git a/tools/utils/ppocr_keys_v1.txt b/tools/utils/ppocr_keys_v1.txt
new file mode 100644
index 0000000000000000000000000000000000000000..84b885d8352226e49b1d5d791b8f43a663e246aa
--- /dev/null
+++ b/tools/utils/ppocr_keys_v1.txt
@@ -0,0 +1,6623 @@
+'
+疗
+绚
+诚
+娇
+溜
+题
+贿
+者
+廖
+更
+纳
+加
+奉
+公
+一
+就
+汴
+计
+与
+路
+房
+原
+妇
+2
+0
+8
+-
+7
+其
+>
+:
+]
+,
+,
+骑
+刈
+全
+消
+昏
+傈
+安
+久
+钟
+嗅
+不
+影
+处
+驽
+蜿
+资
+关
+椤
+地
+瘸
+专
+问
+忖
+票
+嫉
+炎
+韵
+要
+月
+田
+节
+陂
+鄙
+捌
+备
+拳
+伺
+眼
+网
+盎
+大
+傍
+心
+东
+愉
+汇
+蹿
+科
+每
+业
+里
+航
+晏
+字
+平
+录
+先
+1
+3
+彤
+鲶
+产
+稍
+督
+腴
+有
+象
+岳
+注
+绍
+在
+泺
+文
+定
+核
+名
+水
+过
+理
+让
+偷
+率
+等
+这
+发
+”
+为
+含
+肥
+酉
+相
+鄱
+七
+编
+猥
+锛
+日
+镀
+蒂
+掰
+倒
+辆
+栾
+栗
+综
+涩
+州
+雌
+滑
+馀
+了
+机
+块
+司
+宰
+甙
+兴
+矽
+抚
+保
+用
+沧
+秩
+如
+收
+息
+滥
+页
+疑
+埠
+!
+!
+姥
+异
+橹
+钇
+向
+下
+跄
+的
+椴
+沫
+国
+绥
+獠
+报
+开
+民
+蜇
+何
+分
+凇
+长
+讥
+藏
+掏
+施
+羽
+中
+讲
+派
+嘟
+人
+提
+浼
+间
+世
+而
+古
+多
+倪
+唇
+饯
+控
+庚
+首
+赛
+蜓
+味
+断
+制
+觉
+技
+替
+艰
+溢
+潮
+夕
+钺
+外
+摘
+枋
+动
+双
+单
+啮
+户
+枇
+确
+锦
+曜
+杜
+或
+能
+效
+霜
+盒
+然
+侗
+电
+晁
+放
+步
+鹃
+新
+杖
+蜂
+吒
+濂
+瞬
+评
+总
+隍
+对
+独
+合
+也
+是
+府
+青
+天
+诲
+墙
+组
+滴
+级
+邀
+帘
+示
+已
+时
+骸
+仄
+泅
+和
+遨
+店
+雇
+疫
+持
+巍
+踮
+境
+只
+亨
+目
+鉴
+崤
+闲
+体
+泄
+杂
+作
+般
+轰
+化
+解
+迂
+诿
+蛭
+璀
+腾
+告
+版
+服
+省
+师
+小
+规
+程
+线
+海
+办
+引
+二
+桧
+牌
+砺
+洄
+裴
+修
+图
+痫
+胡
+许
+犊
+事
+郛
+基
+柴
+呼
+食
+研
+奶
+律
+蛋
+因
+葆
+察
+戏
+褒
+戒
+再
+李
+骁
+工
+貂
+油
+鹅
+章
+啄
+休
+场
+给
+睡
+纷
+豆
+器
+捎
+说
+敏
+学
+会
+浒
+设
+诊
+格
+廓
+查
+来
+霓
+室
+溆
+¢
+诡
+寥
+焕
+舜
+柒
+狐
+回
+戟
+砾
+厄
+实
+翩
+尿
+五
+入
+径
+惭
+喹
+股
+宇
+篝
+|
+;
+美
+期
+云
+九
+祺
+扮
+靠
+锝
+槌
+系
+企
+酰
+阊
+暂
+蚕
+忻
+豁
+本
+羹
+执
+条
+钦
+H
+獒
+限
+进
+季
+楦
+于
+芘
+玖
+铋
+茯
+未
+答
+粘
+括
+样
+精
+欠
+矢
+甥
+帷
+嵩
+扣
+令
+仔
+风
+皈
+行
+支
+部
+蓉
+刮
+站
+蜡
+救
+钊
+汗
+松
+嫌
+成
+可
+.
+鹤
+院
+从
+交
+政
+怕
+活
+调
+球
+局
+验
+髌
+第
+韫
+谗
+串
+到
+圆
+年
+米
+/
+*
+友
+忿
+检
+区
+看
+自
+敢
+刃
+个
+兹
+弄
+流
+留
+同
+没
+齿
+星
+聆
+轼
+湖
+什
+三
+建
+蛔
+儿
+椋
+汕
+震
+颧
+鲤
+跟
+力
+情
+璺
+铨
+陪
+务
+指
+族
+训
+滦
+鄣
+濮
+扒
+商
+箱
+十
+召
+慷
+辗
+所
+莞
+管
+护
+臭
+横
+硒
+嗓
+接
+侦
+六
+露
+党
+馋
+驾
+剖
+高
+侬
+妪
+幂
+猗
+绺
+骐
+央
+酐
+孝
+筝
+课
+徇
+缰
+门
+男
+西
+项
+句
+谙
+瞒
+秃
+篇
+教
+碲
+罚
+声
+呐
+景
+前
+富
+嘴
+鳌
+稀
+免
+朋
+啬
+睐
+去
+赈
+鱼
+住
+肩
+愕
+速
+旁
+波
+厅
+健
+茼
+厥
+鲟
+谅
+投
+攸
+炔
+数
+方
+击
+呋
+谈
+绩
+别
+愫
+僚
+躬
+鹧
+胪
+炳
+招
+喇
+膨
+泵
+蹦
+毛
+结
+5
+4
+谱
+识
+陕
+粽
+婚
+拟
+构
+且
+搜
+任
+潘
+比
+郢
+妨
+醪
+陀
+桔
+碘
+扎
+选
+哈
+骷
+楷
+亿
+明
+缆
+脯
+监
+睫
+逻
+婵
+共
+赴
+淝
+凡
+惦
+及
+达
+揖
+谩
+澹
+减
+焰
+蛹
+番
+祁
+柏
+员
+禄
+怡
+峤
+龙
+白
+叽
+生
+闯
+起
+细
+装
+谕
+竟
+聚
+钙
+上
+导
+渊
+按
+艾
+辘
+挡
+耒
+盹
+饪
+臀
+记
+邮
+蕙
+受
+各
+医
+搂
+普
+滇
+朗
+茸
+带
+翻
+酚
+(
+光
+堤
+墟
+蔷
+万
+幻
+〓
+瑙
+辈
+昧
+盏
+亘
+蛀
+吉
+铰
+请
+子
+假
+闻
+税
+井
+诩
+哨
+嫂
+好
+面
+琐
+校
+馊
+鬣
+缂
+营
+访
+炖
+占
+农
+缀
+否
+经
+钚
+棵
+趟
+张
+亟
+吏
+茶
+谨
+捻
+论
+迸
+堂
+玉
+信
+吧
+瞠
+乡
+姬
+寺
+咬
+溏
+苄
+皿
+意
+赉
+宝
+尔
+钰
+艺
+特
+唳
+踉
+都
+荣
+倚
+登
+荐
+丧
+奇
+涵
+批
+炭
+近
+符
+傩
+感
+道
+着
+菊
+虹
+仲
+众
+懈
+濯
+颞
+眺
+南
+释
+北
+缝
+标
+既
+茗
+整
+撼
+迤
+贲
+挎
+耱
+拒
+某
+妍
+卫
+哇
+英
+矶
+藩
+治
+他
+元
+领
+膜
+遮
+穗
+蛾
+飞
+荒
+棺
+劫
+么
+市
+火
+温
+拈
+棚
+洼
+转
+果
+奕
+卸
+迪
+伸
+泳
+斗
+邡
+侄
+涨
+屯
+萋
+胭
+氡
+崮
+枞
+惧
+冒
+彩
+斜
+手
+豚
+随
+旭
+淑
+妞
+形
+菌
+吲
+沱
+争
+驯
+歹
+挟
+兆
+柱
+传
+至
+包
+内
+响
+临
+红
+功
+弩
+衡
+寂
+禁
+老
+棍
+耆
+渍
+织
+害
+氵
+渑
+布
+载
+靥
+嗬
+虽
+苹
+咨
+娄
+库
+雉
+榜
+帜
+嘲
+套
+瑚
+亲
+簸
+欧
+边
+6
+腿
+旮
+抛
+吹
+瞳
+得
+镓
+梗
+厨
+继
+漾
+愣
+憨
+士
+策
+窑
+抑
+躯
+襟
+脏
+参
+贸
+言
+干
+绸
+鳄
+穷
+藜
+音
+折
+详
+)
+举
+悍
+甸
+癌
+黎
+谴
+死
+罩
+迁
+寒
+驷
+袖
+媒
+蒋
+掘
+模
+纠
+恣
+观
+祖
+蛆
+碍
+位
+稿
+主
+澧
+跌
+筏
+京
+锏
+帝
+贴
+证
+糠
+才
+黄
+鲸
+略
+炯
+饱
+四
+出
+园
+犀
+牧
+容
+汉
+杆
+浈
+汰
+瑷
+造
+虫
+瘩
+怪
+驴
+济
+应
+花
+沣
+谔
+夙
+旅
+价
+矿
+以
+考
+s
+u
+呦
+晒
+巡
+茅
+准
+肟
+瓴
+詹
+仟
+褂
+译
+桌
+混
+宁
+怦
+郑
+抿
+些
+余
+鄂
+饴
+攒
+珑
+群
+阖
+岔
+琨
+藓
+预
+环
+洮
+岌
+宀
+杲
+瀵
+最
+常
+囡
+周
+踊
+女
+鼓
+袭
+喉
+简
+范
+薯
+遐
+疏
+粱
+黜
+禧
+法
+箔
+斤
+遥
+汝
+奥
+直
+贞
+撑
+置
+绱
+集
+她
+馅
+逗
+钧
+橱
+魉
+[
+恙
+躁
+唤
+9
+旺
+膘
+待
+脾
+惫
+购
+吗
+依
+盲
+度
+瘿
+蠖
+俾
+之
+镗
+拇
+鲵
+厝
+簧
+续
+款
+展
+啃
+表
+剔
+品
+钻
+腭
+损
+清
+锶
+统
+涌
+寸
+滨
+贪
+链
+吠
+冈
+伎
+迥
+咏
+吁
+览
+防
+迅
+失
+汾
+阔
+逵
+绀
+蔑
+列
+川
+凭
+努
+熨
+揪
+利
+俱
+绉
+抢
+鸨
+我
+即
+责
+膦
+易
+毓
+鹊
+刹
+玷
+岿
+空
+嘞
+绊
+排
+术
+估
+锷
+违
+们
+苟
+铜
+播
+肘
+件
+烫
+审
+鲂
+广
+像
+铌
+惰
+铟
+巳
+胍
+鲍
+康
+憧
+色
+恢
+想
+拷
+尤
+疳
+知
+S
+Y
+F
+D
+A
+峄
+裕
+帮
+握
+搔
+氐
+氘
+难
+墒
+沮
+雨
+叁
+缥
+悴
+藐
+湫
+娟
+苑
+稠
+颛
+簇
+后
+阕
+闭
+蕤
+缚
+怎
+佞
+码
+嘤
+蔡
+痊
+舱
+螯
+帕
+赫
+昵
+升
+烬
+岫
+、
+疵
+蜻
+髁
+蕨
+隶
+烛
+械
+丑
+盂
+梁
+强
+鲛
+由
+拘
+揉
+劭
+龟
+撤
+钩
+呕
+孛
+费
+妻
+漂
+求
+阑
+崖
+秤
+甘
+通
+深
+补
+赃
+坎
+床
+啪
+承
+吼
+量
+暇
+钼
+烨
+阂
+擎
+脱
+逮
+称
+P
+神
+属
+矗
+华
+届
+狍
+葑
+汹
+育
+患
+窒
+蛰
+佼
+静
+槎
+运
+鳗
+庆
+逝
+曼
+疱
+克
+代
+官
+此
+麸
+耧
+蚌
+晟
+例
+础
+榛
+副
+测
+唰
+缢
+迹
+灬
+霁
+身
+岁
+赭
+扛
+又
+菡
+乜
+雾
+板
+读
+陷
+徉
+贯
+郁
+虑
+变
+钓
+菜
+圾
+现
+琢
+式
+乐
+维
+渔
+浜
+左
+吾
+脑
+钡
+警
+T
+啵
+拴
+偌
+漱
+湿
+硕
+止
+骼
+魄
+积
+燥
+联
+踢
+玛
+则
+窿
+见
+振
+畿
+送
+班
+钽
+您
+赵
+刨
+印
+讨
+踝
+籍
+谡
+舌
+崧
+汽
+蔽
+沪
+酥
+绒
+怖
+财
+帖
+肱
+私
+莎
+勋
+羔
+霸
+励
+哼
+帐
+将
+帅
+渠
+纪
+婴
+娩
+岭
+厘
+滕
+吻
+伤
+坝
+冠
+戊
+隆
+瘁
+介
+涧
+物
+黍
+并
+姗
+奢
+蹑
+掣
+垸
+锴
+命
+箍
+捉
+病
+辖
+琰
+眭
+迩
+艘
+绌
+繁
+寅
+若
+毋
+思
+诉
+类
+诈
+燮
+轲
+酮
+狂
+重
+反
+职
+筱
+县
+委
+磕
+绣
+奖
+晋
+濉
+志
+徽
+肠
+呈
+獐
+坻
+口
+片
+碰
+几
+村
+柿
+劳
+料
+获
+亩
+惕
+晕
+厌
+号
+罢
+池
+正
+鏖
+煨
+家
+棕
+复
+尝
+懋
+蜥
+锅
+岛
+扰
+队
+坠
+瘾
+钬
+@
+卧
+疣
+镇
+譬
+冰
+彷
+频
+黯
+据
+垄
+采
+八
+缪
+瘫
+型
+熹
+砰
+楠
+襁
+箐
+但
+嘶
+绳
+啤
+拍
+盥
+穆
+傲
+洗
+盯
+塘
+怔
+筛
+丿
+台
+恒
+喂
+葛
+永
+¥
+烟
+酒
+桦
+书
+砂
+蚝
+缉
+态
+瀚
+袄
+圳
+轻
+蛛
+超
+榧
+遛
+姒
+奘
+铮
+右
+荽
+望
+偻
+卡
+丶
+氰
+附
+做
+革
+索
+戚
+坨
+桷
+唁
+垅
+榻
+岐
+偎
+坛
+莨
+山
+殊
+微
+骇
+陈
+爨
+推
+嗝
+驹
+澡
+藁
+呤
+卤
+嘻
+糅
+逛
+侵
+郓
+酌
+德
+摇
+※
+鬃
+被
+慨
+殡
+羸
+昌
+泡
+戛
+鞋
+河
+宪
+沿
+玲
+鲨
+翅
+哽
+源
+铅
+语
+照
+邯
+址
+荃
+佬
+顺
+鸳
+町
+霭
+睾
+瓢
+夸
+椁
+晓
+酿
+痈
+咔
+侏
+券
+噎
+湍
+签
+嚷
+离
+午
+尚
+社
+锤
+背
+孟
+使
+浪
+缦
+潍
+鞅
+军
+姹
+驶
+笑
+鳟
+鲁
+》
+孽
+钜
+绿
+洱
+礴
+焯
+椰
+颖
+囔
+乌
+孔
+巴
+互
+性
+椽
+哞
+聘
+昨
+早
+暮
+胶
+炀
+隧
+低
+彗
+昝
+铁
+呓
+氽
+藉
+喔
+癖
+瑗
+姨
+权
+胱
+韦
+堑
+蜜
+酋
+楝
+砝
+毁
+靓
+歙
+锲
+究
+屋
+喳
+骨
+辨
+碑
+武
+鸠
+宫
+辜
+烊
+适
+坡
+殃
+培
+佩
+供
+走
+蜈
+迟
+翼
+况
+姣
+凛
+浔
+吃
+飘
+债
+犟
+金
+促
+苛
+崇
+坂
+莳
+畔
+绂
+兵
+蠕
+斋
+根
+砍
+亢
+欢
+恬
+崔
+剁
+餐
+榫
+快
+扶
+‖
+濒
+缠
+鳜
+当
+彭
+驭
+浦
+篮
+昀
+锆
+秸
+钳
+弋
+娣
+瞑
+夷
+龛
+苫
+拱
+致
+%
+嵊
+障
+隐
+弑
+初
+娓
+抉
+汩
+累
+蓖
+"
+唬
+助
+苓
+昙
+押
+毙
+破
+城
+郧
+逢
+嚏
+獭
+瞻
+溱
+婿
+赊
+跨
+恼
+璧
+萃
+姻
+貉
+灵
+炉
+密
+氛
+陶
+砸
+谬
+衔
+点
+琛
+沛
+枳
+层
+岱
+诺
+脍
+榈
+埂
+征
+冷
+裁
+打
+蹴
+素
+瘘
+逞
+蛐
+聊
+激
+腱
+萘
+踵
+飒
+蓟
+吆
+取
+咙
+簋
+涓
+矩
+曝
+挺
+揣
+座
+你
+史
+舵
+焱
+尘
+苏
+笈
+脚
+溉
+榨
+诵
+樊
+邓
+焊
+义
+庶
+儋
+蟋
+蒲
+赦
+呷
+杞
+诠
+豪
+还
+试
+颓
+茉
+太
+除
+紫
+逃
+痴
+草
+充
+鳕
+珉
+祗
+墨
+渭
+烩
+蘸
+慕
+璇
+镶
+穴
+嵘
+恶
+骂
+险
+绋
+幕
+碉
+肺
+戳
+刘
+潞
+秣
+纾
+潜
+銮
+洛
+须
+罘
+销
+瘪
+汞
+兮
+屉
+r
+林
+厕
+质
+探
+划
+狸
+殚
+善
+煊
+烹
+〒
+锈
+逯
+宸
+辍
+泱
+柚
+袍
+远
+蹋
+嶙
+绝
+峥
+娥
+缍
+雀
+徵
+认
+镱
+谷
+=
+贩
+勉
+撩
+鄯
+斐
+洋
+非
+祚
+泾
+诒
+饿
+撬
+威
+晷
+搭
+芍
+锥
+笺
+蓦
+候
+琊
+档
+礁
+沼
+卵
+荠
+忑
+朝
+凹
+瑞
+头
+仪
+弧
+孵
+畏
+铆
+突
+衲
+车
+浩
+气
+茂
+悖
+厢
+枕
+酝
+戴
+湾
+邹
+飚
+攘
+锂
+写
+宵
+翁
+岷
+无
+喜
+丈
+挑
+嗟
+绛
+殉
+议
+槽
+具
+醇
+淞
+笃
+郴
+阅
+饼
+底
+壕
+砚
+弈
+询
+缕
+庹
+翟
+零
+筷
+暨
+舟
+闺
+甯
+撞
+麂
+茌
+蔼
+很
+珲
+捕
+棠
+角
+阉
+媛
+娲
+诽
+剿
+尉
+爵
+睬
+韩
+诰
+匣
+危
+糍
+镯
+立
+浏
+阳
+少
+盆
+舔
+擘
+匪
+申
+尬
+铣
+旯
+抖
+赘
+瓯
+居
+ˇ
+哮
+游
+锭
+茏
+歌
+坏
+甚
+秒
+舞
+沙
+仗
+劲
+潺
+阿
+燧
+郭
+嗖
+霏
+忠
+材
+奂
+耐
+跺
+砀
+输
+岖
+媳
+氟
+极
+摆
+灿
+今
+扔
+腻
+枝
+奎
+药
+熄
+吨
+话
+q
+额
+慑
+嘌
+协
+喀
+壳
+埭
+视
+著
+於
+愧
+陲
+翌
+峁
+颅
+佛
+腹
+聋
+侯
+咎
+叟
+秀
+颇
+存
+较
+罪
+哄
+岗
+扫
+栏
+钾
+羌
+己
+璨
+枭
+霉
+煌
+涸
+衿
+键
+镝
+益
+岢
+奏
+连
+夯
+睿
+冥
+均
+糖
+狞
+蹊
+稻
+爸
+刿
+胥
+煜
+丽
+肿
+璃
+掸
+跚
+灾
+垂
+樾
+濑
+乎
+莲
+窄
+犹
+撮
+战
+馄
+软
+络
+显
+鸢
+胸
+宾
+妲
+恕
+埔
+蝌
+份
+遇
+巧
+瞟
+粒
+恰
+剥
+桡
+博
+讯
+凯
+堇
+阶
+滤
+卖
+斌
+骚
+彬
+兑
+磺
+樱
+舷
+两
+娱
+福
+仃
+差
+找
+桁
+÷
+净
+把
+阴
+污
+戬
+雷
+碓
+蕲
+楚
+罡
+焖
+抽
+妫
+咒
+仑
+闱
+尽
+邑
+菁
+爱
+贷
+沥
+鞑
+牡
+嗉
+崴
+骤
+塌
+嗦
+订
+拮
+滓
+捡
+锻
+次
+坪
+杩
+臃
+箬
+融
+珂
+鹗
+宗
+枚
+降
+鸬
+妯
+阄
+堰
+盐
+毅
+必
+杨
+崃
+俺
+甬
+状
+莘
+货
+耸
+菱
+腼
+铸
+唏
+痤
+孚
+澳
+懒
+溅
+翘
+疙
+杷
+淼
+缙
+骰
+喊
+悉
+砻
+坷
+艇
+赁
+界
+谤
+纣
+宴
+晃
+茹
+归
+饭
+梢
+铡
+街
+抄
+肼
+鬟
+苯
+颂
+撷
+戈
+炒
+咆
+茭
+瘙
+负
+仰
+客
+琉
+铢
+封
+卑
+珥
+椿
+镧
+窨
+鬲
+寿
+御
+袤
+铃
+萎
+砖
+餮
+脒
+裳
+肪
+孕
+嫣
+馗
+嵇
+恳
+氯
+江
+石
+褶
+冢
+祸
+阻
+狈
+羞
+银
+靳
+透
+咳
+叼
+敷
+芷
+啥
+它
+瓤
+兰
+痘
+懊
+逑
+肌
+往
+捺
+坊
+甩
+呻
+〃
+沦
+忘
+膻
+祟
+菅
+剧
+崆
+智
+坯
+臧
+霍
+墅
+攻
+眯
+倘
+拢
+骠
+铐
+庭
+岙
+瓠
+′
+缺
+泥
+迢
+捶
+?
+?
+郏
+喙
+掷
+沌
+纯
+秘
+种
+听
+绘
+固
+螨
+团
+香
+盗
+妒
+埚
+蓝
+拖
+旱
+荞
+铀
+血
+遏
+汲
+辰
+叩
+拽
+幅
+硬
+惶
+桀
+漠
+措
+泼
+唑
+齐
+肾
+念
+酱
+虚
+屁
+耶
+旗
+砦
+闵
+婉
+馆
+拭
+绅
+韧
+忏
+窝
+醋
+葺
+顾
+辞
+倜
+堆
+辋
+逆
+玟
+贱
+疾
+董
+惘
+倌
+锕
+淘
+嘀
+莽
+俭
+笏
+绑
+鲷
+杈
+择
+蟀
+粥
+嗯
+驰
+逾
+案
+谪
+褓
+胫
+哩
+昕
+颚
+鲢
+绠
+躺
+鹄
+崂
+儒
+俨
+丝
+尕
+泌
+啊
+萸
+彰
+幺
+吟
+骄
+苣
+弦
+脊
+瑰
+〈
+诛
+镁
+析
+闪
+剪
+侧
+哟
+框
+螃
+守
+嬗
+燕
+狭
+铈
+缮
+概
+迳
+痧
+鲲
+俯
+售
+笼
+痣
+扉
+挖
+满
+咋
+援
+邱
+扇
+歪
+便
+玑
+绦
+峡
+蛇
+叨
+〖
+泽
+胃
+斓
+喋
+怂
+坟
+猪
+该
+蚬
+炕
+弥
+赞
+棣
+晔
+娠
+挲
+狡
+创
+疖
+铕
+镭
+稷
+挫
+弭
+啾
+翔
+粉
+履
+苘
+哦
+楼
+秕
+铂
+土
+锣
+瘟
+挣
+栉
+习
+享
+桢
+袅
+磨
+桂
+谦
+延
+坚
+蔚
+噗
+署
+谟
+猬
+钎
+恐
+嬉
+雒
+倦
+衅
+亏
+璩
+睹
+刻
+殿
+王
+算
+雕
+麻
+丘
+柯
+骆
+丸
+塍
+谚
+添
+鲈
+垓
+桎
+蚯
+芥
+予
+飕
+镦
+谌
+窗
+醚
+菀
+亮
+搪
+莺
+蒿
+羁
+足
+J
+真
+轶
+悬
+衷
+靛
+翊
+掩
+哒
+炅
+掐
+冼
+妮
+l
+谐
+稚
+荆
+擒
+犯
+陵
+虏
+浓
+崽
+刍
+陌
+傻
+孜
+千
+靖
+演
+矜
+钕
+煽
+杰
+酗
+渗
+伞
+栋
+俗
+泫
+戍
+罕
+沾
+疽
+灏
+煦
+芬
+磴
+叱
+阱
+榉
+湃
+蜀
+叉
+醒
+彪
+租
+郡
+篷
+屎
+良
+垢
+隗
+弱
+陨
+峪
+砷
+掴
+颁
+胎
+雯
+绵
+贬
+沐
+撵
+隘
+篙
+暖
+曹
+陡
+栓
+填
+臼
+彦
+瓶
+琪
+潼
+哪
+鸡
+摩
+啦
+俟
+锋
+域
+耻
+蔫
+疯
+纹
+撇
+毒
+绶
+痛
+酯
+忍
+爪
+赳
+歆
+嘹
+辕
+烈
+册
+朴
+钱
+吮
+毯
+癜
+娃
+谀
+邵
+厮
+炽
+璞
+邃
+丐
+追
+词
+瓒
+忆
+轧
+芫
+谯
+喷
+弟
+半
+冕
+裙
+掖
+墉
+绮
+寝
+苔
+势
+顷
+褥
+切
+衮
+君
+佳
+嫒
+蚩
+霞
+佚
+洙
+逊
+镖
+暹
+唛
+&
+殒
+顶
+碗
+獗
+轭
+铺
+蛊
+废
+恹
+汨
+崩
+珍
+那
+杵
+曲
+纺
+夏
+薰
+傀
+闳
+淬
+姘
+舀
+拧
+卷
+楂
+恍
+讪
+厩
+寮
+篪
+赓
+乘
+灭
+盅
+鞣
+沟
+慎
+挂
+饺
+鼾
+杳
+树
+缨
+丛
+絮
+娌
+臻
+嗳
+篡
+侩
+述
+衰
+矛
+圈
+蚜
+匕
+筹
+匿
+濞
+晨
+叶
+骋
+郝
+挚
+蚴
+滞
+增
+侍
+描
+瓣
+吖
+嫦
+蟒
+匾
+圣
+赌
+毡
+癞
+恺
+百
+曳
+需
+篓
+肮
+庖
+帏
+卿
+驿
+遗
+蹬
+鬓
+骡
+歉
+芎
+胳
+屐
+禽
+烦
+晌
+寄
+媾
+狄
+翡
+苒
+船
+廉
+终
+痞
+殇
+々
+畦
+饶
+改
+拆
+悻
+萄
+£
+瓿
+乃
+訾
+桅
+匮
+溧
+拥
+纱
+铍
+骗
+蕃
+龋
+缬
+父
+佐
+疚
+栎
+醍
+掳
+蓄
+x
+惆
+颜
+鲆
+榆
+〔
+猎
+敌
+暴
+谥
+鲫
+贾
+罗
+玻
+缄
+扦
+芪
+癣
+落
+徒
+臾
+恿
+猩
+托
+邴
+肄
+牵
+春
+陛
+耀
+刊
+拓
+蓓
+邳
+堕
+寇
+枉
+淌
+啡
+湄
+兽
+酷
+萼
+碚
+濠
+萤
+夹
+旬
+戮
+梭
+琥
+椭
+昔
+勺
+蜊
+绐
+晚
+孺
+僵
+宣
+摄
+冽
+旨
+萌
+忙
+蚤
+眉
+噼
+蟑
+付
+契
+瓜
+悼
+颡
+壁
+曾
+窕
+颢
+澎
+仿
+俑
+浑
+嵌
+浣
+乍
+碌
+褪
+乱
+蔟
+隙
+玩
+剐
+葫
+箫
+纲
+围
+伐
+决
+伙
+漩
+瑟
+刑
+肓
+镳
+缓
+蹭
+氨
+皓
+典
+畲
+坍
+铑
+檐
+塑
+洞
+倬
+储
+胴
+淳
+戾
+吐
+灼
+惺
+妙
+毕
+珐
+缈
+虱
+盖
+羰
+鸿
+磅
+谓
+髅
+娴
+苴
+唷
+蚣
+霹
+抨
+贤
+唠
+犬
+誓
+逍
+庠
+逼
+麓
+籼
+釉
+呜
+碧
+秧
+氩
+摔
+霄
+穸
+纨
+辟
+妈
+映
+完
+牛
+缴
+嗷
+炊
+恩
+荔
+茆
+掉
+紊
+慌
+莓
+羟
+阙
+萁
+磐
+另
+蕹
+辱
+鳐
+湮
+吡
+吩
+唐
+睦
+垠
+舒
+圜
+冗
+瞿
+溺
+芾
+囱
+匠
+僳
+汐
+菩
+饬
+漓
+黑
+霰
+浸
+濡
+窥
+毂
+蒡
+兢
+驻
+鹉
+芮
+诙
+迫
+雳
+厂
+忐
+臆
+猴
+鸣
+蚪
+栈
+箕
+羡
+渐
+莆
+捍
+眈
+哓
+趴
+蹼
+埕
+嚣
+骛
+宏
+淄
+斑
+噜
+严
+瑛
+垃
+椎
+诱
+压
+庾
+绞
+焘
+廿
+抡
+迄
+棘
+夫
+纬
+锹
+眨
+瞌
+侠
+脐
+竞
+瀑
+孳
+骧
+遁
+姜
+颦
+荪
+滚
+萦
+伪
+逸
+粳
+爬
+锁
+矣
+役
+趣
+洒
+颔
+诏
+逐
+奸
+甭
+惠
+攀
+蹄
+泛
+尼
+拼
+阮
+鹰
+亚
+颈
+惑
+勒
+〉
+际
+肛
+爷
+刚
+钨
+丰
+养
+冶
+鲽
+辉
+蔻
+画
+覆
+皴
+妊
+麦
+返
+醉
+皂
+擀
+〗
+酶
+凑
+粹
+悟
+诀
+硖
+港
+卜
+z
+杀
+涕
+±
+舍
+铠
+抵
+弛
+段
+敝
+镐
+奠
+拂
+轴
+跛
+袱
+e
+t
+沉
+菇
+俎
+薪
+峦
+秭
+蟹
+历
+盟
+菠
+寡
+液
+肢
+喻
+染
+裱
+悱
+抱
+氙
+赤
+捅
+猛
+跑
+氮
+谣
+仁
+尺
+辊
+窍
+烙
+衍
+架
+擦
+倏
+璐
+瑁
+币
+楞
+胖
+夔
+趸
+邛
+惴
+饕
+虔
+蝎
+§
+哉
+贝
+宽
+辫
+炮
+扩
+饲
+籽
+魏
+菟
+锰
+伍
+猝
+末
+琳
+哚
+蛎
+邂
+呀
+姿
+鄞
+却
+歧
+仙
+恸
+椐
+森
+牒
+寤
+袒
+婆
+虢
+雅
+钉
+朵
+贼
+欲
+苞
+寰
+故
+龚
+坭
+嘘
+咫
+礼
+硷
+兀
+睢
+汶
+’
+铲
+烧
+绕
+诃
+浃
+钿
+哺
+柜
+讼
+颊
+璁
+腔
+洽
+咐
+脲
+簌
+筠
+镣
+玮
+鞠
+谁
+兼
+姆
+挥
+梯
+蝴
+谘
+漕
+刷
+躏
+宦
+弼
+b
+垌
+劈
+麟
+莉
+揭
+笙
+渎
+仕
+嗤
+仓
+配
+怏
+抬
+错
+泯
+镊
+孰
+猿
+邪
+仍
+秋
+鼬
+壹
+歇
+吵
+炼
+<
+尧
+射
+柬
+廷
+胧
+霾
+凳
+隋
+肚
+浮
+梦
+祥
+株
+堵
+退
+L
+鹫
+跎
+凶
+毽
+荟
+炫
+栩
+玳
+甜
+沂
+鹿
+顽
+伯
+爹
+赔
+蛴
+徐
+匡
+欣
+狰
+缸
+雹
+蟆
+疤
+默
+沤
+啜
+痂
+衣
+禅
+w
+i
+h
+辽
+葳
+黝
+钗
+停
+沽
+棒
+馨
+颌
+肉
+吴
+硫
+悯
+劾
+娈
+马
+啧
+吊
+悌
+镑
+峭
+帆
+瀣
+涉
+咸
+疸
+滋
+泣
+翦
+拙
+癸
+钥
+蜒
++
+尾
+庄
+凝
+泉
+婢
+渴
+谊
+乞
+陆
+锉
+糊
+鸦
+淮
+I
+B
+N
+晦
+弗
+乔
+庥
+葡
+尻
+席
+橡
+傣
+渣
+拿
+惩
+麋
+斛
+缃
+矮
+蛏
+岘
+鸽
+姐
+膏
+催
+奔
+镒
+喱
+蠡
+摧
+钯
+胤
+柠
+拐
+璋
+鸥
+卢
+荡
+倾
+^
+_
+珀
+逄
+萧
+塾
+掇
+贮
+笆
+聂
+圃
+冲
+嵬
+M
+滔
+笕
+值
+炙
+偶
+蜱
+搐
+梆
+汪
+蔬
+腑
+鸯
+蹇
+敞
+绯
+仨
+祯
+谆
+梧
+糗
+鑫
+啸
+豺
+囹
+猾
+巢
+柄
+瀛
+筑
+踌
+沭
+暗
+苁
+鱿
+蹉
+脂
+蘖
+牢
+热
+木
+吸
+溃
+宠
+序
+泞
+偿
+拜
+檩
+厚
+朐
+毗
+螳
+吞
+媚
+朽
+担
+蝗
+橘
+畴
+祈
+糟
+盱
+隼
+郜
+惜
+珠
+裨
+铵
+焙
+琚
+唯
+咚
+噪
+骊
+丫
+滢
+勤
+棉
+呸
+咣
+淀
+隔
+蕾
+窈
+饨
+挨
+煅
+短
+匙
+粕
+镜
+赣
+撕
+墩
+酬
+馁
+豌
+颐
+抗
+酣
+氓
+佑
+搁
+哭
+递
+耷
+涡
+桃
+贻
+碣
+截
+瘦
+昭
+镌
+蔓
+氚
+甲
+猕
+蕴
+蓬
+散
+拾
+纛
+狼
+猷
+铎
+埋
+旖
+矾
+讳
+囊
+糜
+迈
+粟
+蚂
+紧
+鲳
+瘢
+栽
+稼
+羊
+锄
+斟
+睁
+桥
+瓮
+蹙
+祉
+醺
+鼻
+昱
+剃
+跳
+篱
+跷
+蒜
+翎
+宅
+晖
+嗑
+壑
+峻
+癫
+屏
+狠
+陋
+袜
+途
+憎
+祀
+莹
+滟
+佶
+溥
+臣
+约
+盛
+峰
+磁
+慵
+婪
+拦
+莅
+朕
+鹦
+粲
+裤
+哎
+疡
+嫖
+琵
+窟
+堪
+谛
+嘉
+儡
+鳝
+斩
+郾
+驸
+酊
+妄
+胜
+贺
+徙
+傅
+噌
+钢
+栅
+庇
+恋
+匝
+巯
+邈
+尸
+锚
+粗
+佟
+蛟
+薹
+纵
+蚊
+郅
+绢
+锐
+苗
+俞
+篆
+淆
+膀
+鲜
+煎
+诶
+秽
+寻
+涮
+刺
+怀
+噶
+巨
+褰
+魅
+灶
+灌
+桉
+藕
+谜
+舸
+薄
+搀
+恽
+借
+牯
+痉
+渥
+愿
+亓
+耘
+杠
+柩
+锔
+蚶
+钣
+珈
+喘
+蹒
+幽
+赐
+稗
+晤
+莱
+泔
+扯
+肯
+菪
+裆
+腩
+豉
+疆
+骜
+腐
+倭
+珏
+唔
+粮
+亡
+润
+慰
+伽
+橄
+玄
+誉
+醐
+胆
+龊
+粼
+塬
+陇
+彼
+削
+嗣
+绾
+芽
+妗
+垭
+瘴
+爽
+薏
+寨
+龈
+泠
+弹
+赢
+漪
+猫
+嘧
+涂
+恤
+圭
+茧
+烽
+屑
+痕
+巾
+赖
+荸
+凰
+腮
+畈
+亵
+蹲
+偃
+苇
+澜
+艮
+换
+骺
+烘
+苕
+梓
+颉
+肇
+哗
+悄
+氤
+涠
+葬
+屠
+鹭
+植
+竺
+佯
+诣
+鲇
+瘀
+鲅
+邦
+移
+滁
+冯
+耕
+癔
+戌
+茬
+沁
+巩
+悠
+湘
+洪
+痹
+锟
+循
+谋
+腕
+鳃
+钠
+捞
+焉
+迎
+碱
+伫
+急
+榷
+奈
+邝
+卯
+辄
+皲
+卟
+醛
+畹
+忧
+稳
+雄
+昼
+缩
+阈
+睑
+扌
+耗
+曦
+涅
+捏
+瞧
+邕
+淖
+漉
+铝
+耦
+禹
+湛
+喽
+莼
+琅
+诸
+苎
+纂
+硅
+始
+嗨
+傥
+燃
+臂
+赅
+嘈
+呆
+贵
+屹
+壮
+肋
+亍
+蚀
+卅
+豹
+腆
+邬
+迭
+浊
+}
+童
+螂
+捐
+圩
+勐
+触
+寞
+汊
+壤
+荫
+膺
+渌
+芳
+懿
+遴
+螈
+泰
+蓼
+蛤
+茜
+舅
+枫
+朔
+膝
+眙
+避
+梅
+判
+鹜
+璜
+牍
+缅
+垫
+藻
+黔
+侥
+惚
+懂
+踩
+腰
+腈
+札
+丞
+唾
+慈
+顿
+摹
+荻
+琬
+~
+斧
+沈
+滂
+胁
+胀
+幄
+莜
+Z
+匀
+鄄
+掌
+绰
+茎
+焚
+赋
+萱
+谑
+汁
+铒
+瞎
+夺
+蜗
+野
+娆
+冀
+弯
+篁
+懵
+灞
+隽
+芡
+脘
+俐
+辩
+芯
+掺
+喏
+膈
+蝈
+觐
+悚
+踹
+蔗
+熠
+鼠
+呵
+抓
+橼
+峨
+畜
+缔
+禾
+崭
+弃
+熊
+摒
+凸
+拗
+穹
+蒙
+抒
+祛
+劝
+闫
+扳
+阵
+醌
+踪
+喵
+侣
+搬
+仅
+荧
+赎
+蝾
+琦
+买
+婧
+瞄
+寓
+皎
+冻
+赝
+箩
+莫
+瞰
+郊
+笫
+姝
+筒
+枪
+遣
+煸
+袋
+舆
+痱
+涛
+母
+〇
+启
+践
+耙
+绲
+盘
+遂
+昊
+搞
+槿
+诬
+纰
+泓
+惨
+檬
+亻
+越
+C
+o
+憩
+熵
+祷
+钒
+暧
+塔
+阗
+胰
+咄
+娶
+魔
+琶
+钞
+邻
+扬
+杉
+殴
+咽
+弓
+〆
+髻
+】
+吭
+揽
+霆
+拄
+殖
+脆
+彻
+岩
+芝
+勃
+辣
+剌
+钝
+嘎
+甄
+佘
+皖
+伦
+授
+徕
+憔
+挪
+皇
+庞
+稔
+芜
+踏
+溴
+兖
+卒
+擢
+饥
+鳞
+煲
+‰
+账
+颗
+叻
+斯
+捧
+鳍
+琮
+讹
+蛙
+纽
+谭
+酸
+兔
+莒
+睇
+伟
+觑
+羲
+嗜
+宜
+褐
+旎
+辛
+卦
+诘
+筋
+鎏
+溪
+挛
+熔
+阜
+晰
+鳅
+丢
+奚
+灸
+呱
+献
+陉
+黛
+鸪
+甾
+萨
+疮
+拯
+洲
+疹
+辑
+叙
+恻
+谒
+允
+柔
+烂
+氏
+逅
+漆
+拎
+惋
+扈
+湟
+纭
+啕
+掬
+擞
+哥
+忽
+涤
+鸵
+靡
+郗
+瓷
+扁
+廊
+怨
+雏
+钮
+敦
+E
+懦
+憋
+汀
+拚
+啉
+腌
+岸
+f
+痼
+瞅
+尊
+咀
+眩
+飙
+忌
+仝
+迦
+熬
+毫
+胯
+篑
+茄
+腺
+凄
+舛
+碴
+锵
+诧
+羯
+後
+漏
+汤
+宓
+仞
+蚁
+壶
+谰
+皑
+铄
+棰
+罔
+辅
+晶
+苦
+牟
+闽
+\
+烃
+饮
+聿
+丙
+蛳
+朱
+煤
+涔
+鳖
+犁
+罐
+荼
+砒
+淦
+妤
+黏
+戎
+孑
+婕
+瑾
+戢
+钵
+枣
+捋
+砥
+衩
+狙
+桠
+稣
+阎
+肃
+梏
+诫
+孪
+昶
+婊
+衫
+嗔
+侃
+塞
+蜃
+樵
+峒
+貌
+屿
+欺
+缫
+阐
+栖
+诟
+珞
+荭
+吝
+萍
+嗽
+恂
+啻
+蜴
+磬
+峋
+俸
+豫
+谎
+徊
+镍
+韬
+魇
+晴
+U
+囟
+猜
+蛮
+坐
+囿
+伴
+亭
+肝
+佗
+蝠
+妃
+胞
+滩
+榴
+氖
+垩
+苋
+砣
+扪
+馏
+姓
+轩
+厉
+夥
+侈
+禀
+垒
+岑
+赏
+钛
+辐
+痔
+披
+纸
+碳
+“
+坞
+蠓
+挤
+荥
+沅
+悔
+铧
+帼
+蒌
+蝇
+a
+p
+y
+n
+g
+哀
+浆
+瑶
+凿
+桶
+馈
+皮
+奴
+苜
+佤
+伶
+晗
+铱
+炬
+优
+弊
+氢
+恃
+甫
+攥
+端
+锌
+灰
+稹
+炝
+曙
+邋
+亥
+眶
+碾
+拉
+萝
+绔
+捷
+浍
+腋
+姑
+菖
+凌
+涞
+麽
+锢
+桨
+潢
+绎
+镰
+殆
+锑
+渝
+铬
+困
+绽
+觎
+匈
+糙
+暑
+裹
+鸟
+盔
+肽
+迷
+綦
+『
+亳
+佝
+俘
+钴
+觇
+骥
+仆
+疝
+跪
+婶
+郯
+瀹
+唉
+脖
+踞
+针
+晾
+忒
+扼
+瞩
+叛
+椒
+疟
+嗡
+邗
+肆
+跆
+玫
+忡
+捣
+咧
+唆
+艄
+蘑
+潦
+笛
+阚
+沸
+泻
+掊
+菽
+贫
+斥
+髂
+孢
+镂
+赂
+麝
+鸾
+屡
+衬
+苷
+恪
+叠
+希
+粤
+爻
+喝
+茫
+惬
+郸
+绻
+庸
+撅
+碟
+宄
+妹
+膛
+叮
+饵
+崛
+嗲
+椅
+冤
+搅
+咕
+敛
+尹
+垦
+闷
+蝉
+霎
+勰
+败
+蓑
+泸
+肤
+鹌
+幌
+焦
+浠
+鞍
+刁
+舰
+乙
+竿
+裔
+。
+茵
+函
+伊
+兄
+丨
+娜
+匍
+謇
+莪
+宥
+似
+蝽
+翳
+酪
+翠
+粑
+薇
+祢
+骏
+赠
+叫
+Q
+噤
+噻
+竖
+芗
+莠
+潭
+俊
+羿
+耜
+O
+郫
+趁
+嗪
+囚
+蹶
+芒
+洁
+笋
+鹑
+敲
+硝
+啶
+堡
+渲
+揩
+』
+携
+宿
+遒
+颍
+扭
+棱
+割
+萜
+蔸
+葵
+琴
+捂
+饰
+衙
+耿
+掠
+募
+岂
+窖
+涟
+蔺
+瘤
+柞
+瞪
+怜
+匹
+距
+楔
+炜
+哆
+秦
+缎
+幼
+茁
+绪
+痨
+恨
+楸
+娅
+瓦
+桩
+雪
+嬴
+伏
+榔
+妥
+铿
+拌
+眠
+雍
+缇
+‘
+卓
+搓
+哌
+觞
+噩
+屈
+哧
+髓
+咦
+巅
+娑
+侑
+淫
+膳
+祝
+勾
+姊
+莴
+胄
+疃
+薛
+蜷
+胛
+巷
+芙
+芋
+熙
+闰
+勿
+窃
+狱
+剩
+钏
+幢
+陟
+铛
+慧
+靴
+耍
+k
+浙
+浇
+飨
+惟
+绗
+祜
+澈
+啼
+咪
+磷
+摞
+诅
+郦
+抹
+跃
+壬
+吕
+肖
+琏
+颤
+尴
+剡
+抠
+凋
+赚
+泊
+津
+宕
+殷
+倔
+氲
+漫
+邺
+涎
+怠
+$
+垮
+荬
+遵
+俏
+叹
+噢
+饽
+蜘
+孙
+筵
+疼
+鞭
+羧
+牦
+箭
+潴
+c
+眸
+祭
+髯
+啖
+坳
+愁
+芩
+驮
+倡
+巽
+穰
+沃
+胚
+怒
+凤
+槛
+剂
+趵
+嫁
+v
+邢
+灯
+鄢
+桐
+睽
+檗
+锯
+槟
+婷
+嵋
+圻
+诗
+蕈
+颠
+遭
+痢
+芸
+怯
+馥
+竭
+锗
+徜
+恭
+遍
+籁
+剑
+嘱
+苡
+龄
+僧
+桑
+潸
+弘
+澶
+楹
+悲
+讫
+愤
+腥
+悸
+谍
+椹
+呢
+桓
+葭
+攫
+阀
+翰
+躲
+敖
+柑
+郎
+笨
+橇
+呃
+魁
+燎
+脓
+葩
+磋
+垛
+玺
+狮
+沓
+砜
+蕊
+锺
+罹
+蕉
+翱
+虐
+闾
+巫
+旦
+茱
+嬷
+枯
+鹏
+贡
+芹
+汛
+矫
+绁
+拣
+禺
+佃
+讣
+舫
+惯
+乳
+趋
+疲
+挽
+岚
+虾
+衾
+蠹
+蹂
+飓
+氦
+铖
+孩
+稞
+瑜
+壅
+掀
+勘
+妓
+畅
+髋
+W
+庐
+牲
+蓿
+榕
+练
+垣
+唱
+邸
+菲
+昆
+婺
+穿
+绡
+麒
+蚱
+掂
+愚
+泷
+涪
+漳
+妩
+娉
+榄
+讷
+觅
+旧
+藤
+煮
+呛
+柳
+腓
+叭
+庵
+烷
+阡
+罂
+蜕
+擂
+猖
+咿
+媲
+脉
+【
+沏
+貅
+黠
+熏
+哲
+烁
+坦
+酵
+兜
+×
+潇
+撒
+剽
+珩
+圹
+乾
+摸
+樟
+帽
+嗒
+襄
+魂
+轿
+憬
+锡
+〕
+喃
+皆
+咖
+隅
+脸
+残
+泮
+袂
+鹂
+珊
+囤
+捆
+咤
+误
+徨
+闹
+淙
+芊
+淋
+怆
+囗
+拨
+梳
+渤
+R
+G
+绨
+蚓
+婀
+幡
+狩
+麾
+谢
+唢
+裸
+旌
+伉
+纶
+裂
+驳
+砼
+咛
+澄
+樨
+蹈
+宙
+澍
+倍
+貔
+操
+勇
+蟠
+摈
+砧
+虬
+够
+缁
+悦
+藿
+撸
+艹
+摁
+淹
+豇
+虎
+榭
+ˉ
+吱
+d
+°
+喧
+荀
+踱
+侮
+奋
+偕
+饷
+犍
+惮
+坑
+璎
+徘
+宛
+妆
+袈
+倩
+窦
+昂
+荏
+乖
+K
+怅
+撰
+鳙
+牙
+袁
+酞
+X
+痿
+琼
+闸
+雁
+趾
+荚
+虻
+涝
+《
+杏
+韭
+偈
+烤
+绫
+鞘
+卉
+症
+遢
+蓥
+诋
+杭
+荨
+匆
+竣
+簪
+辙
+敕
+虞
+丹
+缭
+咩
+黟
+m
+淤
+瑕
+咂
+铉
+硼
+茨
+嶂
+痒
+畸
+敬
+涿
+粪
+窘
+熟
+叔
+嫔
+盾
+忱
+裘
+憾
+梵
+赡
+珙
+咯
+娘
+庙
+溯
+胺
+葱
+痪
+摊
+荷
+卞
+乒
+髦
+寐
+铭
+坩
+胗
+枷
+爆
+溟
+嚼
+羚
+砬
+轨
+惊
+挠
+罄
+竽
+菏
+氧
+浅
+楣
+盼
+枢
+炸
+阆
+杯
+谏
+噬
+淇
+渺
+俪
+秆
+墓
+泪
+跻
+砌
+痰
+垡
+渡
+耽
+釜
+讶
+鳎
+煞
+呗
+韶
+舶
+绷
+鹳
+缜
+旷
+铊
+皱
+龌
+檀
+霖
+奄
+槐
+艳
+蝶
+旋
+哝
+赶
+骞
+蚧
+腊
+盈
+丁
+`
+蜚
+矸
+蝙
+睨
+嚓
+僻
+鬼
+醴
+夜
+彝
+磊
+笔
+拔
+栀
+糕
+厦
+邰
+纫
+逭
+纤
+眦
+膊
+馍
+躇
+烯
+蘼
+冬
+诤
+暄
+骶
+哑
+瘠
+」
+臊
+丕
+愈
+咱
+螺
+擅
+跋
+搏
+硪
+谄
+笠
+淡
+嘿
+骅
+谧
+鼎
+皋
+姚
+歼
+蠢
+驼
+耳
+胬
+挝
+涯
+狗
+蒽
+孓
+犷
+凉
+芦
+箴
+铤
+孤
+嘛
+坤
+V
+茴
+朦
+挞
+尖
+橙
+诞
+搴
+碇
+洵
+浚
+帚
+蜍
+漯
+柘
+嚎
+讽
+芭
+荤
+咻
+祠
+秉
+跖
+埃
+吓
+糯
+眷
+馒
+惹
+娼
+鲑
+嫩
+讴
+轮
+瞥
+靶
+褚
+乏
+缤
+宋
+帧
+删
+驱
+碎
+扑
+俩
+俄
+偏
+涣
+竹
+噱
+皙
+佰
+渚
+唧
+斡
+#
+镉
+刀
+崎
+筐
+佣
+夭
+贰
+肴
+峙
+哔
+艿
+匐
+牺
+镛
+缘
+仡
+嫡
+劣
+枸
+堀
+梨
+簿
+鸭
+蒸
+亦
+稽
+浴
+{
+衢
+束
+槲
+j
+阁
+揍
+疥
+棋
+潋
+聪
+窜
+乓
+睛
+插
+冉
+阪
+苍
+搽
+「
+蟾
+螟
+幸
+仇
+樽
+撂
+慢
+跤
+幔
+俚
+淅
+覃
+觊
+溶
+妖
+帛
+侨
+曰
+妾
+泗
+·
+:
+瀘
+風
+Ë
+(
+)
+∶
+紅
+紗
+瑭
+雲
+頭
+鶏
+財
+許
+•
+¥
+樂
+焗
+麗
+—
+;
+滙
+東
+榮
+繪
+興
+…
+門
+業
+π
+楊
+國
+顧
+é
+盤
+寳
+Λ
+龍
+鳳
+島
+誌
+緣
+結
+銭
+萬
+勝
+祎
+璟
+優
+歡
+臨
+時
+購
+=
+★
+藍
+昇
+鐵
+觀
+勅
+農
+聲
+畫
+兿
+術
+發
+劉
+記
+專
+耑
+園
+書
+壴
+種
+Ο
+●
+褀
+號
+銀
+匯
+敟
+锘
+葉
+橪
+廣
+進
+蒄
+鑽
+阝
+祙
+貢
+鍋
+豊
+夬
+喆
+團
+閣
+開
+燁
+賓
+館
+酡
+沔
+順
++
+硚
+劵
+饸
+陽
+車
+湓
+復
+萊
+氣
+軒
+華
+堃
+迮
+纟
+戶
+馬
+學
+裡
+電
+嶽
+獨
+マ
+シ
+サ
+ジ
+燘
+袪
+環
+❤
+臺
+灣
+専
+賣
+孖
+聖
+攝
+線
+▪
+α
+傢
+俬
+夢
+達
+莊
+喬
+貝
+薩
+劍
+羅
+壓
+棛
+饦
+尃
+璈
+囍
+醫
+G
+I
+A
+#
+N
+鷄
+髙
+嬰
+啓
+約
+隹
+潔
+賴
+藝
+~
+寶
+籣
+麺
+
+嶺
+√
+義
+網
+峩
+長
+∧
+魚
+機
+構
+②
+鳯
+偉
+L
+B
+㙟
+畵
+鴿
+'
+詩
+溝
+嚞
+屌
+藔
+佧
+玥
+蘭
+織
+1
+3
+9
+0
+7
+點
+砭
+鴨
+鋪
+銘
+廳
+弍
+‧
+創
+湯
+坶
+℃
+卩
+骝
+&
+烜
+荘
+當
+潤
+扞
+係
+懷
+碶
+钅
+蚨
+讠
+☆
+叢
+爲
+埗
+涫
+塗
+→
+楽
+現
+鯨
+愛
+瑪
+鈺
+忄
+悶
+藥
+飾
+樓
+視
+孬
+ㆍ
+燚
+苪
+師
+①
+丼
+锽
+│
+韓
+標
+è
+兒
+閏
+匋
+張
+漢
+Ü
+髪
+會
+閑
+檔
+習
+裝
+の
+峯
+菘
+輝
+И
+雞
+釣
+億
+浐
+K
+O
+R
+8
+H
+E
+P
+T
+W
+D
+S
+C
+M
+F
+姌
+饹
+»
+晞
+廰
+ä
+嵯
+鷹
+負
+飲
+絲
+冚
+楗
+澤
+綫
+區
+❋
+←
+質
+靑
+揚
+③
+滬
+統
+産
+協
+﹑
+乸
+畐
+經
+運
+際
+洺
+岽
+為
+粵
+諾
+崋
+豐
+碁
+ɔ
+V
+2
+6
+齋
+誠
+訂
+´
+勑
+雙
+陳
+無
+í
+泩
+媄
+夌
+刂
+i
+c
+t
+o
+r
+a
+嘢
+耄
+燴
+暃
+壽
+媽
+靈
+抻
+體
+唻
+É
+冮
+甹
+鎮
+錦
+ʌ
+蜛
+蠄
+尓
+駕
+戀
+飬
+逹
+倫
+貴
+極
+Я
+Й
+寬
+磚
+嶪
+郎
+職
+|
+間
+n
+d
+剎
+伈
+課
+飛
+橋
+瘊
+№
+譜
+骓
+圗
+滘
+縣
+粿
+咅
+養
+濤
+彳
+®
+%
+Ⅱ
+啰
+㴪
+見
+矞
+薬
+糁
+邨
+鲮
+顔
+罱
+З
+選
+話
+贏
+氪
+俵
+競
+瑩
+繡
+枱
+β
+綉
+á
+獅
+爾
+™
+麵
+戋
+淩
+徳
+個
+劇
+場
+務
+簡
+寵
+h
+實
+膠
+轱
+圖
+築
+嘣
+樹
+㸃
+營
+耵
+孫
+饃
+鄺
+飯
+麯
+遠
+輸
+坫
+孃
+乚
+閃
+鏢
+㎡
+題
+廠
+關
+↑
+爺
+將
+軍
+連
+篦
+覌
+參
+箸
+-
+窠
+棽
+寕
+夀
+爰
+歐
+呙
+閥
+頡
+熱
+雎
+垟
+裟
+凬
+勁
+帑
+馕
+夆
+疌
+枼
+馮
+貨
+蒤
+樸
+彧
+旸
+靜
+龢
+暢
+㐱
+鳥
+珺
+鏡
+灡
+爭
+堷
+廚
+Ó
+騰
+診
+┅
+蘇
+褔
+凱
+頂
+豕
+亞
+帥
+嘬
+⊥
+仺
+桖
+複
+饣
+絡
+穂
+顏
+棟
+納
+▏
+濟
+親
+設
+計
+攵
+埌
+烺
+ò
+頤
+燦
+蓮
+撻
+節
+講
+濱
+濃
+娽
+洳
+朿
+燈
+鈴
+護
+膚
+铔
+過
+補
+Z
+U
+5
+4
+坋
+闿
+䖝
+餘
+缐
+铞
+貿
+铪
+桼
+趙
+鍊
+[
+㐂
+垚
+菓
+揸
+捲
+鐘
+滏
+𣇉
+爍
+輪
+燜
+鴻
+鮮
+動
+鹞
+鷗
+丄
+慶
+鉌
+翥
+飮
+腸
+⇋
+漁
+覺
+來
+熘
+昴
+翏
+鲱
+圧
+鄉
+萭
+頔
+爐
+嫚
+г
+貭
+類
+聯
+幛
+輕
+訓
+鑒
+夋
+锨
+芃
+珣
+䝉
+扙
+嵐
+銷
+處
+ㄱ
+語
+誘
+苝
+歸
+儀
+燒
+楿
+內
+粢
+葒
+奧
+麥
+礻
+滿
+蠔
+穵
+瞭
+態
+鱬
+榞
+硂
+鄭
+黃
+煙
+祐
+奓
+逺
+*
+瑄
+獲
+聞
+薦
+讀
+這
+樣
+決
+問
+啟
+們
+執
+説
+轉
+單
+隨
+唘
+帶
+倉
+庫
+還
+贈
+尙
+皺
+■
+餅
+產
+○
+∈
+報
+狀
+楓
+賠
+琯
+嗮
+禮
+`
+傳
+>
+≤
+嗞
+Φ
+≥
+換
+咭
+∣
+↓
+曬
+ε
+応
+寫
+″
+終
+様
+純
+費
+療
+聨
+凍
+壐
+郵
+ü
+黒
+∫
+製
+塊
+調
+軽
+確
+撃
+級
+馴
+Ⅲ
+涇
+繹
+數
+碼
+證
+狒
+処
+劑
+<
+晧
+賀
+衆
+]
+櫥
+兩
+陰
+絶
+對
+鯉
+憶
+◎
+p
+e
+Y
+蕒
+煖
+頓
+測
+試
+鼽
+僑
+碩
+妝
+帯
+≈
+鐡
+舖
+權
+喫
+倆
+ˋ
+該
+悅
+ā
+俫
+.
+f
+s
+b
+m
+k
+g
+u
+j
+貼
+淨
+濕
+針
+適
+備
+l
+/
+給
+謢
+強
+觸
+衛
+與
+⊙
+$
+緯
+變
+⑴
+⑵
+⑶
+㎏
+殺
+∩
+幚
+─
+價
+▲
+離
+ú
+ó
+飄
+烏
+関
+閟
+﹝
+﹞
+邏
+輯
+鍵
+驗
+訣
+導
+歷
+屆
+層
+▼
+儱
+錄
+熳
+ē
+艦
+吋
+錶
+辧
+飼
+顯
+④
+禦
+販
+気
+対
+枰
+閩
+紀
+幹
+瞓
+貊
+淚
+△
+眞
+墊
+Ω
+獻
+褲
+縫
+緑
+亜
+鉅
+餠
+{
+}
+◆
+蘆
+薈
+█
+◇
+溫
+彈
+晳
+粧
+犸
+穩
+訊
+崬
+凖
+熥
+П
+舊
+條
+紋
+圍
+Ⅳ
+筆
+尷
+難
+雜
+錯
+綁
+識
+頰
+鎖
+艶
+□
+殁
+殼
+⑧
+├
+▕
+鵬
+ǐ
+ō
+ǒ
+糝
+綱
+▎
+μ
+盜
+饅
+醬
+籤
+蓋
+釀
+鹽
+據
+à
+ɡ
+辦
+◥
+彐
+┌
+婦
+獸
+鲩
+伱
+ī
+蒟
+蒻
+齊
+袆
+腦
+寧
+凈
+妳
+煥
+詢
+偽
+謹
+啫
+鯽
+騷
+鱸
+損
+傷
+鎻
+髮
+買
+冏
+儥
+両
+﹢
+∞
+載
+喰
+z
+羙
+悵
+燙
+曉
+員
+組
+徹
+艷
+痠
+鋼
+鼙
+縮
+細
+嚒
+爯
+≠
+維
+"
+鱻
+壇
+厍
+帰
+浥
+犇
+薡
+軎
+²
+應
+醜
+刪
+緻
+鶴
+賜
+噁
+軌
+尨
+镔
+鷺
+槗
+彌
+葚
+濛
+請
+溇
+緹
+賢
+訪
+獴
+瑅
+資
+縤
+陣
+蕟
+栢
+韻
+祼
+恁
+伢
+謝
+劃
+涑
+總
+衖
+踺
+砋
+凉
+籃
+駿
+苼
+瘋
+昽
+紡
+驊
+腎
+﹗
+響
+杋
+剛
+嚴
+禪
+歓
+槍
+傘
+檸
+檫
+炣
+勢
+鏜
+鎢
+銑
+尐
+減
+奪
+惡
+θ
+僮
+婭
+臘
+ū
+ì
+殻
+鉄
+∑
+蛲
+焼
+緖
+續
+紹
+懮
\ No newline at end of file
diff --git a/tools/utils/stats.py b/tools/utils/stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..96654f31fd927656891e66562f50dc470e3f2990
--- /dev/null
+++ b/tools/utils/stats.py
@@ -0,0 +1,58 @@
+import collections
+import numpy as np
+import datetime
+
+__all__ = ["TrainingStats", "Time"]
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size):
+ self.deque = collections.deque(maxlen=window_size)
+
+ def add_value(self, value):
+ self.deque.append(value)
+
+ def get_median_value(self):
+ return np.median(self.deque)
+
+
+def Time():
+ return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
+
+
+class TrainingStats(object):
+ def __init__(self, window_size, stats_keys):
+ self.window_size = window_size
+ self.smoothed_losses_and_metrics = {
+ key: SmoothedValue(window_size)
+ for key in stats_keys
+ }
+
+ def update(self, stats):
+ for k, v in stats.items():
+ if k not in self.smoothed_losses_and_metrics:
+ self.smoothed_losses_and_metrics[k] = SmoothedValue(
+ self.window_size)
+ self.smoothed_losses_and_metrics[k].add_value(v)
+
+ def get(self, extras=None):
+ stats = collections.OrderedDict()
+ if extras:
+ for k, v in extras.items():
+ stats[k] = v
+ for k, v in self.smoothed_losses_and_metrics.items():
+ stats[k] = round(v.get_median_value(), 6)
+
+ return stats
+
+ def log(self, extras=None):
+ d = self.get(extras)
+ strs = []
+ for k, v in d.items():
+            strs.append("{}: {:.6f}".format(k, v))
+ strs = ", ".join(strs)
+ return strs
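+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative): track a fake loss over a 5-step window
+    # and print the smoothed summary together with an extra stat.
+    train_stats = TrainingStats(window_size=5, stats_keys=["loss"])
+    for step in range(10):
+        train_stats.update({"loss": 1.0 / (step + 1)})
+    print(Time(), train_stats.log({"lr": 0.001}))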
diff --git a/tools/utils/utility.py b/tools/utils/utility.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c8e765a978795dee81f37ee1065f5c6abaccb38
--- /dev/null
+++ b/tools/utils/utility.py
@@ -0,0 +1,165 @@
+import logging
+import os
+import cv2
+import numpy as np
+import importlib.util
+import sys
+import subprocess
+
+
+def get_check_global_params(mode):
+ check_params = [
+ "use_gpu",
+ "max_text_length",
+ "image_shape",
+ "image_shape",
+ "character_type",
+ "loss_type",
+ ]
+ if mode == "train_eval":
+ check_params = check_params + [
+ "train_batch_size_per_card",
+ "test_batch_size_per_card",
+ ]
+ elif mode == "test":
+ check_params = check_params + ["test_batch_size_per_card"]
+ return check_params
+
+
+def _check_image_file(path):
+ img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif", "pdf"}
+ return any([path.lower().endswith(e) for e in img_end])
+
+
+def get_image_file_list(img_file):
+ imgs_lists = []
+    if img_file is None or not os.path.exists(img_file):
+        raise Exception("no image file found in {}".format(img_file))
+
+ if os.path.isfile(img_file) and _check_image_file(img_file):
+ imgs_lists.append(img_file)
+ elif os.path.isdir(img_file):
+ for single_file in os.listdir(img_file):
+ file_path = os.path.join(img_file, single_file)
+ if os.path.isfile(file_path) and _check_image_file(file_path):
+ imgs_lists.append(file_path)
+    if len(imgs_lists) == 0:
+        raise Exception("no image file found in {}".format(img_file))
+ imgs_lists = sorted(imgs_lists)
+ return imgs_lists
+
+
+def binarize_img(img):
+ if len(img.shape) == 3 and img.shape[2] == 3:
+ gray = cv2.cvtColor(img,
+ cv2.COLOR_BGR2GRAY) # conversion to grayscale image
+ # use cv2 threshold binarization
+ _, gray = cv2.threshold(gray, 0, 255,
+ cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+ img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
+ return img
+
+
+def alpha_to_color(img, alpha_color=(255, 255, 255)):
+ if len(img.shape) == 3 and img.shape[2] == 4:
+ B, G, R, A = cv2.split(img)
+ alpha = A / 255
+
+ R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
+ G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
+ B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
+
+ img = cv2.merge((B, G, R))
+ return img
+
+
+def check_and_read(img_path):
+ if os.path.basename(img_path)[-3:].lower() == "gif":
+ gif = cv2.VideoCapture(img_path)
+ ret, frame = gif.read()
+ if not ret:
+            logger = logging.getLogger("openrec")
+            logger.info(
+                "Cannot read {}. This gif image may be corrupted.".format(
+                    img_path))
+            return None, False, False
+ if len(frame.shape) == 2 or frame.shape[-1] == 1:
+ frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
+ imgvalue = frame[:, :, ::-1]
+ return imgvalue, True, False
+ elif os.path.basename(img_path)[-3:].lower() == "pdf":
+ import fitz
+ from PIL import Image
+
+ imgs = []
+ with fitz.open(img_path) as pdf:
+ for pg in range(0, pdf.page_count):
+ page = pdf[pg]
+ mat = fitz.Matrix(2, 2)
+ pm = page.get_pixmap(matrix=mat, alpha=False)
+
+ # if width or height > 2000 pixels, don't enlarge the image
+ if pm.width > 2000 or pm.height > 2000:
+ pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
+
+ img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+ imgs.append(img)
+ return imgs, False, True
+ return None, False, False
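+
+# Usage sketch (illustrative; the path is a placeholder): check_and_read
+# returns (data, is_gif, is_pdf).
+#
+#   img_path = "./docs/sample_image_or_pdf"
+#   data, is_gif, is_pdf = check_and_read(img_path)
+#   if is_pdf:
+#       pages = data                  # list of BGR page images
+#   elif is_gif:
+#       frame = data                  # single RGB frame
+#   else:
+#       frame = cv2.imread(img_path)  # plain image: read it directly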
+
+
+def load_vqa_bio_label_maps(label_map_path):
+ with open(label_map_path, "r", encoding="utf-8") as fin:
+ lines = fin.readlines()
+ old_lines = [line.strip() for line in lines]
+ lines = ["O"]
+ for line in old_lines:
+ # "O" has already been in lines
+ if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
+ continue
+ lines.append(line)
+ labels = ["O"]
+ for line in lines[1:]:
+ labels.append("B-" + line)
+ labels.append("I-" + line)
+ label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
+ id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
+ return label2id_map, id2label_map
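+
+# Worked example (illustrative): a label-map file with the three lines
+# HEADER / QUESTION / ANSWER yields
+#   label2id_map = {"O": 0, "B-HEADER": 1, "I-HEADER": 2, "B-QUESTION": 3,
+#                   "I-QUESTION": 4, "B-ANSWER": 5, "I-ANSWER": 6}
+# and id2label_map is the inverse mapping.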
+
+
+def check_install(module_name, install_name):
+ spec = importlib.util.find_spec(module_name)
+ if spec is None:
+ print(f"Warnning! The {module_name} module is NOT installed")
+ print(
+ f"Try install {module_name} module automatically. You can also try to install manually by pip install {install_name}."
+ )
+ python = sys.executable
+ try:
+ subprocess.check_call(
+ [python, "-m", "pip", "install", install_name],
+ stdout=subprocess.DEVNULL, )
+ print(f"The {module_name} module is now installed")
+        except subprocess.CalledProcessError as exc:
+            raise Exception(
+                f"Installing {module_name} failed, please install it manually"
+            ) from exc
+ else:
+ print(f"{module_name} has been installed.")
+
+
+class AverageMeter:
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ """reset"""
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ """update"""
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
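+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative): running average of three batch losses.
+    meter = AverageMeter()
+    for loss in (0.9, 0.7, 0.5):
+        meter.update(loss)
+    print("avg loss: {:.3f}".format(meter.avg))  # -> 0.700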
diff --git a/tools/utils/visual.py b/tools/utils/visual.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a271923e728c9f6bd82c138305cefa02788dde0
--- /dev/null
+++ b/tools/utils/visual.py
@@ -0,0 +1,117 @@
+import cv2
+import random
+import math
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+
+__all__ = ["draw_system", "draw_det"]
+
+
+def draw_det(dt_boxes, img):
+ for box in dt_boxes:
+ box = np.array(box).astype(np.int32).reshape(-1, 2)
+ cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
+ return img
+
+
+def draw_system(
+ image,
+ boxes,
+ txts=None,
+ scores=None,
+ drop_score=0.5,
+ font_path="./doc/fonts/simfang.ttf", ):
+ h, w = image.height, image.width
+ img_left = image.copy()
+ img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
+ random.seed(0)
+
+ draw_left = ImageDraw.Draw(img_left)
+ if txts is None or len(txts) != len(boxes):
+ txts = [None] * len(boxes)
+ for idx, (box, txt) in enumerate(zip(boxes, txts)):
+ if scores is not None and scores[idx] < drop_score:
+ continue
+ color = (random.randint(0, 255), random.randint(0, 255),
+ random.randint(0, 255))
+ draw_left.polygon(box, fill=color)
+ img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
+ pts = np.array(box, np.int32).reshape((-1, 1, 2))
+ cv2.polylines(img_right_text, [pts], True, color, 1)
+ img_right = cv2.bitwise_and(img_right, img_right_text)
+ img_left = Image.blend(image, img_left, 0.5)
+ img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
+ img_show.paste(img_left, (0, 0, w, h))
+ img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
+ return np.array(img_show)
+
+
+def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"):
+ box_height = int(
+ math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][1])**2))
+ box_width = int(
+ math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][1])**2))
+
+ if box_height > 2 * box_width and box_height > 30:
+ img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))
+ draw_text = ImageDraw.Draw(img_text)
+ if txt:
+ font = create_font(txt, (box_height, box_width), font_path)
+ draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+ img_text = img_text.transpose(Image.ROTATE_270)
+ else:
+ img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))
+ draw_text = ImageDraw.Draw(img_text)
+ if txt:
+ font = create_font(txt, (box_width, box_height), font_path)
+ draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+
+ pts1 = np.float32(
+ [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]])
+ pts2 = np.array(box, dtype=np.float32)
+ M = cv2.getPerspectiveTransform(pts1, pts2)
+
+ img_text = np.array(img_text, dtype=np.uint8)
+ img_right_text = cv2.warpPerspective(
+ img_text,
+ M,
+ img_size,
+ flags=cv2.INTER_NEAREST,
+ borderMode=cv2.BORDER_CONSTANT,
+ borderValue=(255, 255, 255), )
+ return img_right_text
+
+
+def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"):
+ font_size = int(sz[1] * 0.99)
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+ length = font.getlength(txt)
+ if length > sz[0]:
+ font_size = int(font_size * sz[0] / length)
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+ return font
+
+
+def str_count(s):
+ """
+ Count the number of Chinese characters,
+ a single English character and a single number
+ equal to half the length of Chinese characters.
+ args:
+ s(string): the input of string
+ return(int):
+ the number of Chinese characters
+ """
+ import string
+
+ count_zh = count_pu = 0
+ s_len = len(s)
+ en_dg_count = 0
+ for c in s:
+ if c in string.ascii_letters or c.isdigit() or c.isspace():
+ en_dg_count += 1
+ elif c.isalpha():
+ count_zh += 1
+ else:
+ count_pu += 1
+ return s_len - math.ceil(en_dg_count / 2)
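+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative): draw two dummy quadrilaterals on a
+    # blank canvas and save the result (the output path is a placeholder).
+    canvas = np.zeros((200, 400, 3), dtype=np.uint8)
+    demo_boxes = [
+        [[10, 10], [190, 10], [190, 60], [10, 60]],
+        [[20, 80], [380, 80], [380, 150], [20, 150]],
+    ]
+    vis = draw_det(demo_boxes, canvas)
+    cv2.imwrite("./det_vis_demo.jpg", vis)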