Spaces:
Running
Running
initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +176 -0
- Dockerfile +45 -0
- IndicPhotoOCR/__init__.py +0 -0
- IndicPhotoOCR/detection/__init__.py +0 -0
- IndicPhotoOCR/detection/east_config.py +39 -0
- IndicPhotoOCR/detection/east_detector.py +87 -0
- IndicPhotoOCR/detection/east_locality_aware_nms.py +75 -0
- IndicPhotoOCR/detection/east_model.py +242 -0
- IndicPhotoOCR/detection/east_preprossing.py +681 -0
- IndicPhotoOCR/detection/east_utils.py +283 -0
- IndicPhotoOCR/ocr.py +154 -0
- IndicPhotoOCR/recognition/__init__.py +0 -0
- IndicPhotoOCR/recognition/parseq_recogniser.py +215 -0
- IndicPhotoOCR/script_identification/CLIP_identifier.py +201 -0
- IndicPhotoOCR/script_identification/__init__.py +0 -0
- IndicPhotoOCR/theme.py +43 -0
- IndicPhotoOCR/utils/strhub/__init__.py +2 -0
- IndicPhotoOCR/utils/strhub/data/__init__.py +1 -0
- IndicPhotoOCR/utils/strhub/data/aa_overrides.py +46 -0
- IndicPhotoOCR/utils/strhub/data/augment.py +112 -0
- IndicPhotoOCR/utils/strhub/data/dataset.py +148 -0
- IndicPhotoOCR/utils/strhub/data/module.py +157 -0
- IndicPhotoOCR/utils/strhub/data/utils.py +150 -0
- IndicPhotoOCR/utils/strhub/models/__init__.py +1 -0
- IndicPhotoOCR/utils/strhub/models/abinet/LICENSE +25 -0
- IndicPhotoOCR/utils/strhub/models/abinet/__init__.py +13 -0
- IndicPhotoOCR/utils/strhub/models/abinet/attention.py +100 -0
- IndicPhotoOCR/utils/strhub/models/abinet/backbone.py +24 -0
- IndicPhotoOCR/utils/strhub/models/abinet/model.py +31 -0
- IndicPhotoOCR/utils/strhub/models/abinet/model_abinet_iter.py +39 -0
- IndicPhotoOCR/utils/strhub/models/abinet/model_alignment.py +28 -0
- IndicPhotoOCR/utils/strhub/models/abinet/model_language.py +49 -0
- IndicPhotoOCR/utils/strhub/models/abinet/model_vision.py +45 -0
- IndicPhotoOCR/utils/strhub/models/abinet/resnet.py +72 -0
- IndicPhotoOCR/utils/strhub/models/abinet/system.py +215 -0
- IndicPhotoOCR/utils/strhub/models/abinet/transformer.py +198 -0
- IndicPhotoOCR/utils/strhub/models/base.py +221 -0
- IndicPhotoOCR/utils/strhub/models/crnn/LICENSE +21 -0
- IndicPhotoOCR/utils/strhub/models/crnn/__init__.py +13 -0
- IndicPhotoOCR/utils/strhub/models/crnn/model.py +62 -0
- IndicPhotoOCR/utils/strhub/models/crnn/system.py +56 -0
- IndicPhotoOCR/utils/strhub/models/modules.py +20 -0
- IndicPhotoOCR/utils/strhub/models/parseq/__init__.py +0 -0
- IndicPhotoOCR/utils/strhub/models/parseq/model.py +169 -0
- IndicPhotoOCR/utils/strhub/models/parseq/modules.py +176 -0
- IndicPhotoOCR/utils/strhub/models/parseq/system.py +200 -0
- IndicPhotoOCR/utils/strhub/models/trba/__init__.py +13 -0
- IndicPhotoOCR/utils/strhub/models/trba/feature_extraction.py +110 -0
- IndicPhotoOCR/utils/strhub/models/trba/model.py +55 -0
- IndicPhotoOCR/utils/strhub/models/trba/prediction.py +73 -0
.gitignore
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Output directories
|
2 |
+
outputs/
|
3 |
+
multirun/
|
4 |
+
ray_results/
|
5 |
+
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
|
11 |
+
# C extensions
|
12 |
+
*.so
|
13 |
+
|
14 |
+
# Distribution / packaging
|
15 |
+
# requirements/core.*.txt
|
16 |
+
.Python
|
17 |
+
build/
|
18 |
+
develop-eggs/
|
19 |
+
dist/
|
20 |
+
downloads/
|
21 |
+
eggs/
|
22 |
+
.eggs/
|
23 |
+
lib/
|
24 |
+
lib64/
|
25 |
+
parts/
|
26 |
+
sdist/
|
27 |
+
var/
|
28 |
+
wheels/
|
29 |
+
share/python-wheels/
|
30 |
+
*.egg-info/
|
31 |
+
.installed.cfg
|
32 |
+
*.egg
|
33 |
+
MANIFEST
|
34 |
+
|
35 |
+
# PyInstaller
|
36 |
+
# Usually these files are written by a python script from a template
|
37 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
38 |
+
*.manifest
|
39 |
+
*.spec
|
40 |
+
|
41 |
+
# Installer logs
|
42 |
+
pip-log.txt
|
43 |
+
pip-delete-this-directory.txt
|
44 |
+
|
45 |
+
# Unit test / coverage reports
|
46 |
+
htmlcov/
|
47 |
+
.tox/
|
48 |
+
.nox/
|
49 |
+
.coverage
|
50 |
+
.coverage.*
|
51 |
+
.cache
|
52 |
+
nosetests.xml
|
53 |
+
coverage.xml
|
54 |
+
*.cover
|
55 |
+
*.py,cover
|
56 |
+
.hypothesis/
|
57 |
+
.pytest_cache/
|
58 |
+
cover/
|
59 |
+
|
60 |
+
# Translations
|
61 |
+
*.mo
|
62 |
+
*.pot
|
63 |
+
|
64 |
+
# Django stuff:
|
65 |
+
*.log
|
66 |
+
local_settings.py
|
67 |
+
db.sqlite3
|
68 |
+
db.sqlite3-journal
|
69 |
+
|
70 |
+
# Flask stuff:
|
71 |
+
instance/
|
72 |
+
.webassets-cache
|
73 |
+
|
74 |
+
# Scrapy stuff:
|
75 |
+
.scrapy
|
76 |
+
|
77 |
+
# Sphinx documentation
|
78 |
+
docs/_build/
|
79 |
+
|
80 |
+
# PyBuilder
|
81 |
+
.pybuilder/
|
82 |
+
target/
|
83 |
+
|
84 |
+
# Jupyter Notebook
|
85 |
+
.ipynb_checkpoints
|
86 |
+
|
87 |
+
# IPython
|
88 |
+
profile_default/
|
89 |
+
ipython_config.py
|
90 |
+
|
91 |
+
# pyenv
|
92 |
+
# For a library or package, you might want to ignore these files since the code is
|
93 |
+
# intended to run in multiple environments; otherwise, check them in:
|
94 |
+
# .python-version
|
95 |
+
|
96 |
+
# pipenv
|
97 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
98 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
99 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
100 |
+
# install all needed dependencies.
|
101 |
+
#Pipfile.lock
|
102 |
+
|
103 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
104 |
+
__pypackages__/
|
105 |
+
|
106 |
+
# Celery stuff
|
107 |
+
celerybeat-schedule
|
108 |
+
celerybeat.pid
|
109 |
+
|
110 |
+
# SageMath parsed files
|
111 |
+
*.sage.py
|
112 |
+
|
113 |
+
# Environments
|
114 |
+
.env
|
115 |
+
.venv
|
116 |
+
env/
|
117 |
+
venv/
|
118 |
+
ENV/
|
119 |
+
env.bak/
|
120 |
+
venv.bak/
|
121 |
+
.python-version
|
122 |
+
|
123 |
+
# Spyder project settings
|
124 |
+
.spyderproject
|
125 |
+
.spyproject
|
126 |
+
|
127 |
+
# Rope project settings
|
128 |
+
.ropeproject
|
129 |
+
|
130 |
+
# mkdocs documentation
|
131 |
+
/site
|
132 |
+
|
133 |
+
# mypy
|
134 |
+
.mypy_cache/
|
135 |
+
.dmypy.json
|
136 |
+
dmypy.json
|
137 |
+
|
138 |
+
# Pyre type checker
|
139 |
+
.pyre/
|
140 |
+
|
141 |
+
# pytype static type analyzer
|
142 |
+
.pytype/
|
143 |
+
|
144 |
+
# Cython debug symbols
|
145 |
+
cython_debug/
|
146 |
+
|
147 |
+
# IDE
|
148 |
+
.idea/
|
149 |
+
|
150 |
+
########## CUSTOM FOLDER ##############
|
151 |
+
README_original.md
|
152 |
+
|
153 |
+
results/
|
154 |
+
images
|
155 |
+
bharatSTR/East/tmp
|
156 |
+
bharatSTR/models
|
157 |
+
bharatSTR/images
|
158 |
+
__pycache__/
|
159 |
+
bharatSTR/
|
160 |
+
|
161 |
+
IndicPhotoOCR/detection/East
|
162 |
+
IndicPhotoOCR/recognition/models
|
163 |
+
|
164 |
+
IndicPhotoOCR/script_identification/images
|
165 |
+
IndicPhotoOCR/script_identification/models
|
166 |
+
|
167 |
+
|
168 |
+
build/
|
169 |
+
dist/
|
170 |
+
test.png
|
171 |
+
static/pics/IndicPhotoOCR.gif
|
172 |
+
input_image.jpg
|
173 |
+
output_image.png
|
174 |
+
|
175 |
+
|
176 |
+
|
Dockerfile
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use NVIDIA PyTorch as the base image
|
2 |
+
FROM nvcr.io/nvidia/pytorch:23.12-py3
|
3 |
+
|
4 |
+
# Install additional dependencies
|
5 |
+
RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6
|
6 |
+
|
7 |
+
# Set environment variables for Miniconda and Conda environment
|
8 |
+
ENV CONDA_DIR /opt/conda
|
9 |
+
ENV PATH $CONDA_DIR/bin:$PATH
|
10 |
+
|
11 |
+
# Install Miniconda
|
12 |
+
RUN apt-get update && apt-get install -y wget && \
|
13 |
+
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
14 |
+
bash Miniconda3-latest-Linux-x86_64.sh -b -p $CONDA_DIR && \
|
15 |
+
rm Miniconda3-latest-Linux-x86_64.sh
|
16 |
+
|
17 |
+
# Create a new Conda environment named "bocr" with Python 3.9
|
18 |
+
RUN conda create -n bocr python=3.9 -y
|
19 |
+
|
20 |
+
# Initialize conda
|
21 |
+
RUN conda init
|
22 |
+
|
23 |
+
# Reload the env configs
|
24 |
+
RUN source ~/.bashrc
|
25 |
+
|
26 |
+
# Make RUN commands use the bocr environment
|
27 |
+
SHELL ["conda", "run", "-n", "bocr", "/bin/bash", "-c"]
|
28 |
+
|
29 |
+
# # Set default shell to bash
|
30 |
+
# SHELL ["/bin/bash", "-c"]
|
31 |
+
|
32 |
+
# # Clone BharatOCR repository
|
33 |
+
# RUN git clone https://github.com/Bhashini-IITJ/BharatOCR.git && \
|
34 |
+
# git switch photoOCR && \
|
35 |
+
# cd IndicPhotoOCR && \
|
36 |
+
# python setup.py sdist bdist_wheel && \
|
37 |
+
# pip install ./dist/IndicPhotoOCR-1.1.0-py3-none-any.whl[cu118] --extra-index-url https://download.pytorch.org/whl/cu118
|
38 |
+
|
39 |
+
# # # Set default command to run BharatOCR
|
40 |
+
# CMD ["conda", "run", "-n", "bocr", "python", "-m", "IndicPhotoOCR.ocr"]
|
41 |
+
|
42 |
+
|
43 |
+
# cd IndicPhotoOCR
|
44 |
+
# sudo docker build -t indicphotoocr:latest .
|
45 |
+
# sudo docker run --gpus all --rm -it indicphotoocr:latest
|
IndicPhotoOCR/__init__.py
ADDED
File without changes
|
IndicPhotoOCR/detection/__init__.py
ADDED
File without changes
|
IndicPhotoOCR/detection/east_config.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# data-config
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
train_data_path = './dataset/train/'
|
6 |
+
train_batch_size_per_gpu = 14 # 14
|
7 |
+
num_workers = 24 # 24
|
8 |
+
gpu_ids = [0] # [0,1,2,3]
|
9 |
+
gpu = 1 # 4
|
10 |
+
input_size = 512 # 预处理后归一化后图像尺寸
|
11 |
+
background_ratio = 3. / 8 # 纯背景样本比例
|
12 |
+
random_scale = np.array([0.5, 1, 2.0, 3.0]) # 提取多尺度图片信息
|
13 |
+
geometry = 'RBOX' # 选择使用几何特征图类型
|
14 |
+
max_image_large_side = 1280
|
15 |
+
max_text_size = 800
|
16 |
+
min_text_size = 10
|
17 |
+
min_crop_side_ratio = 0.1
|
18 |
+
means=[100, 100, 100]
|
19 |
+
pretrained = True # 是否加载基础网络的预训练模型
|
20 |
+
pretrained_basemodel_path = 'IndicPhotoOCR/detection/East/tmp/backbone_net/mobilenet_v2.pth.tar'
|
21 |
+
pre_lr = 1e-4 # 基础网络的初始学习率
|
22 |
+
lr = 1e-3 # 后面网络的初始学习率
|
23 |
+
decay_steps = 50 # decayed_learning_rate = learning_rate * decay_rate ^ (global_epoch / decay_steps)
|
24 |
+
decay_rate = 0.97
|
25 |
+
init_type = 'xavier' # 网络参数初始化方式
|
26 |
+
resume = True # 整体网络是否恢复原来保存的模型
|
27 |
+
checkpoint = 'IndicPhotoOCR/detection/East/tmp/epoch_990_checkpoint.pth.tar' # 指定具体路径及文件名
|
28 |
+
max_epochs = 1000 # 最大迭代epochs数
|
29 |
+
l2_weight_decay = 1e-6 # l2正则化惩罚项权重
|
30 |
+
print_freq = 10 # 每10个batch输出损失结果
|
31 |
+
save_eval_iteration = 50 # 每10个epoch保存一次模型,并做一次评价
|
32 |
+
save_model_path = './tmp/' # 模型保存路径
|
33 |
+
test_img_path = './dataset/full_set' # demo测试样本路径'./demo/test_img/',数据集测试为'./dataset/test/'
|
34 |
+
res_img_path = 'results' # demo结果存放路径'./demo/result_img/',数据集测试为 './dataset/test_result/'
|
35 |
+
write_images = True # 是否输出图像结果
|
36 |
+
score_map_thresh = 0.8 # 置信度阈值
|
37 |
+
box_thresh = 0.1 # 文本框中置信度平均值的阈值
|
38 |
+
nms_thres = 0.2 # 局部非极大抑制IOU阈值
|
39 |
+
compute_hmean_path = './dataset/test_compute_hmean/'
|
IndicPhotoOCR/detection/east_detector.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import time
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
|
9 |
+
import IndicPhotoOCR.detection.east_config as cfg
|
10 |
+
from IndicPhotoOCR.detection.east_utils import ModelManager
|
11 |
+
from IndicPhotoOCR.detection.east_model import East
|
12 |
+
import IndicPhotoOCR.detection.east_utils as utils
|
13 |
+
|
14 |
+
# Suppress warnings
|
15 |
+
warnings.filterwarnings("ignore")
|
16 |
+
|
17 |
+
class EASTdetector:
|
18 |
+
def __init__(self, model_name= "east", model_path=None):
|
19 |
+
self.model_path = model_path
|
20 |
+
# self.model_manager = ModelManager()
|
21 |
+
# self.model_manager.ensure_model(model_name)
|
22 |
+
# self.ensure_model(self.model_name)
|
23 |
+
# self.root_model_dir = "BharatSTR/bharatOCR/detection/East/tmp"
|
24 |
+
|
25 |
+
def detect(self, image_path, model_checkpoint, device):
|
26 |
+
# Load image
|
27 |
+
im = cv2.imread(image_path)
|
28 |
+
# im = cv2.imread(image_path)[:, :, ::-1]
|
29 |
+
|
30 |
+
# Initialize the EAST model and load checkpoint
|
31 |
+
model = East()
|
32 |
+
model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
|
33 |
+
|
34 |
+
# Load the model checkpoint with weights_only=True
|
35 |
+
checkpoint = torch.load(model_checkpoint, map_location=torch.device(device), weights_only=True)
|
36 |
+
model.load_state_dict(checkpoint['state_dict'])
|
37 |
+
model.eval()
|
38 |
+
|
39 |
+
# Resize image and convert to tensor format
|
40 |
+
im_resized, (ratio_h, ratio_w) = utils.resize_image(im)
|
41 |
+
im_resized = im_resized.astype(np.float32).transpose(2, 0, 1)
|
42 |
+
im_tensor = torch.from_numpy(im_resized).unsqueeze(0).cpu()
|
43 |
+
|
44 |
+
# Inference
|
45 |
+
timer = {'net': 0, 'restore': 0, 'nms': 0}
|
46 |
+
start = time.time()
|
47 |
+
score, geometry = model(im_tensor)
|
48 |
+
timer['net'] = time.time() - start
|
49 |
+
|
50 |
+
# Process output
|
51 |
+
score = score.permute(0, 2, 3, 1).data.cpu().numpy()
|
52 |
+
geometry = geometry.permute(0, 2, 3, 1).data.cpu().numpy()
|
53 |
+
|
54 |
+
# Detect boxes
|
55 |
+
boxes, timer = utils.detect(
|
56 |
+
score_map=score, geo_map=geometry, timer=timer,
|
57 |
+
score_map_thresh=cfg.score_map_thresh, box_thresh=cfg.box_thresh,
|
58 |
+
nms_thres=cfg.box_thresh
|
59 |
+
)
|
60 |
+
bbox_result_dict = {'detections': []}
|
61 |
+
|
62 |
+
# Parse detected boxes and adjust coordinates
|
63 |
+
if boxes is not None:
|
64 |
+
boxes = boxes[:, :8].reshape((-1, 4, 2))
|
65 |
+
boxes[:, :, 0] /= ratio_w
|
66 |
+
boxes[:, :, 1] /= ratio_h
|
67 |
+
for box in boxes:
|
68 |
+
box = utils.sort_poly(box.astype(np.int32))
|
69 |
+
if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
|
70 |
+
continue
|
71 |
+
bbox_result_dict['detections'].append([
|
72 |
+
[int(coord[0]), int(coord[1])] for coord in box
|
73 |
+
])
|
74 |
+
|
75 |
+
return bbox_result_dict
|
76 |
+
|
77 |
+
# if __name__ == "__main__":
|
78 |
+
# import argparse
|
79 |
+
# parser = argparse.ArgumentParser(description='Text detection using EAST model')
|
80 |
+
# parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
|
81 |
+
# parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
|
82 |
+
# parser.add_argument('--model_checkpoint', type=str, required=True, help='Path to the model checkpoint file')
|
83 |
+
# args = parser.parse_args()
|
84 |
+
|
85 |
+
# # Run prediction and get results as dictionary
|
86 |
+
# detection_result = predict(args.image_path, args.device, args.model_checkpoint)
|
87 |
+
# print(detection_result)
|
IndicPhotoOCR/detection/east_locality_aware_nms.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
from shapely.geometry import Polygon
|
4 |
+
|
5 |
+
|
6 |
+
def intersection(g, p):
|
7 |
+
g = Polygon(g[:8].reshape((4, 2)))
|
8 |
+
p = Polygon(p[:8].reshape((4, 2)))
|
9 |
+
if not g.is_valid or not p.is_valid:
|
10 |
+
return 0
|
11 |
+
inter = Polygon(g).intersection(Polygon(p)).area
|
12 |
+
union = g.area + p.area - inter
|
13 |
+
if union == 0:
|
14 |
+
return 0
|
15 |
+
else:
|
16 |
+
return inter/union
|
17 |
+
|
18 |
+
|
19 |
+
def weighted_merge(g, p):
|
20 |
+
# g[0]=min(g[0],p[0])
|
21 |
+
# g[1] = min(g[1], p[1])
|
22 |
+
# g[4] = max(g[4], p[4])
|
23 |
+
# g[5]= max(g[5],p[5])
|
24 |
+
#
|
25 |
+
# g[2] = max(g[2], p[2])
|
26 |
+
# g[3] = min(g[3], p[3])
|
27 |
+
# g[6] = min(g[6], p[6])
|
28 |
+
# g[7] = max(g[7], p[7])
|
29 |
+
|
30 |
+
g[:8] = (g[8] * g[:8] + p[8] * p[:8])/(g[8] + p[8])
|
31 |
+
g[8] = (g[8] + p[8])
|
32 |
+
return g
|
33 |
+
|
34 |
+
|
35 |
+
def standard_nms(S, thres):
|
36 |
+
order = np.argsort(S[:, 8])[::-1]
|
37 |
+
keep = []
|
38 |
+
while order.size > 0:
|
39 |
+
i = order[0]
|
40 |
+
keep.append(i)
|
41 |
+
ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
|
42 |
+
|
43 |
+
inds = np.where(ovr <= thres)[0]
|
44 |
+
order = order[inds+1]
|
45 |
+
|
46 |
+
return S[keep]
|
47 |
+
|
48 |
+
|
49 |
+
def nms_locality(polys, thres=0.3):
|
50 |
+
'''
|
51 |
+
locality aware nms of EAST
|
52 |
+
:param polys: a N*9 numpy array. first 8 coordinates, then prob
|
53 |
+
:return: boxes after nms
|
54 |
+
'''
|
55 |
+
S = []
|
56 |
+
p = None
|
57 |
+
for g in polys:
|
58 |
+
if p is not None and intersection(g, p) > thres:
|
59 |
+
p = weighted_merge(g, p)
|
60 |
+
else:
|
61 |
+
if p is not None:
|
62 |
+
S.append(p)
|
63 |
+
p = g
|
64 |
+
if p is not None:
|
65 |
+
S.append(p)
|
66 |
+
|
67 |
+
if len(S) == 0:
|
68 |
+
return np.array([])
|
69 |
+
return standard_nms(np.array(S), thres)
|
70 |
+
|
71 |
+
|
72 |
+
if __name__ == '__main__':
|
73 |
+
# 343,350,448,135,474,143,369,359
|
74 |
+
print(Polygon(np.array([[343, 350], [448, 135],
|
75 |
+
[474, 143], [369, 359]])).area)
|
IndicPhotoOCR/detection/east_model.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch.nn as nn
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
from IndicPhotoOCR.detection import east_config as cfg
|
8 |
+
from IndicPhotoOCR.detection import east_utils as utils
|
9 |
+
|
10 |
+
|
11 |
+
def conv_bn(inp, oup, stride):
|
12 |
+
return nn.Sequential(
|
13 |
+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
|
14 |
+
nn.BatchNorm2d(oup),
|
15 |
+
nn.ReLU6(inplace=True)
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
class InvertedResidual(nn.Module):
|
20 |
+
def __init__(self, inp, oup, stride, expand_ratio):
|
21 |
+
super(InvertedResidual, self).__init__()
|
22 |
+
self.stride = stride
|
23 |
+
assert stride in [1, 2]
|
24 |
+
|
25 |
+
hidden_dim = round(inp * expand_ratio)
|
26 |
+
self.use_res_connect = self.stride == 1 and inp == oup
|
27 |
+
|
28 |
+
if expand_ratio == 1:
|
29 |
+
self.conv = nn.Sequential(
|
30 |
+
# dw
|
31 |
+
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
32 |
+
nn.BatchNorm2d(hidden_dim),
|
33 |
+
nn.ReLU6(inplace=True),
|
34 |
+
# pw-linear
|
35 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
36 |
+
nn.BatchNorm2d(oup),
|
37 |
+
)
|
38 |
+
else:
|
39 |
+
self.conv = nn.Sequential(
|
40 |
+
# pw
|
41 |
+
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
|
42 |
+
nn.BatchNorm2d(hidden_dim),
|
43 |
+
nn.ReLU6(inplace=True),
|
44 |
+
# dw
|
45 |
+
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
46 |
+
nn.BatchNorm2d(hidden_dim),
|
47 |
+
nn.ReLU6(inplace=True),
|
48 |
+
# pw-linear
|
49 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
50 |
+
nn.BatchNorm2d(oup),
|
51 |
+
)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
if self.use_res_connect:
|
55 |
+
return x + self.conv(x)
|
56 |
+
else:
|
57 |
+
return self.conv(x)
|
58 |
+
|
59 |
+
|
60 |
+
class MobileNetV2(nn.Module):
|
61 |
+
def __init__(self, width_mult=1.):
|
62 |
+
super(MobileNetV2, self).__init__()
|
63 |
+
block = InvertedResidual
|
64 |
+
input_channel = 32
|
65 |
+
last_channel = 1280
|
66 |
+
interverted_residual_setting = [
|
67 |
+
# t, c, n, s
|
68 |
+
[1, 16, 1, 1],
|
69 |
+
[6, 24, 2, 2],
|
70 |
+
[6, 32, 3, 2],
|
71 |
+
[6, 64, 4, 2],
|
72 |
+
[6, 96, 3, 1],
|
73 |
+
[6, 160, 3, 2],
|
74 |
+
# [6, 320, 1, 1],
|
75 |
+
]
|
76 |
+
|
77 |
+
# building first layer
|
78 |
+
# assert input_size % 32 == 0
|
79 |
+
input_channel = int(input_channel * width_mult)
|
80 |
+
self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
|
81 |
+
self.features = [conv_bn(3, input_channel, 2)]
|
82 |
+
# building inverted residual blocks
|
83 |
+
for t, c, n, s in interverted_residual_setting:
|
84 |
+
output_channel = int(c * width_mult)
|
85 |
+
for i in range(n):
|
86 |
+
if i == 0:
|
87 |
+
self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
|
88 |
+
else:
|
89 |
+
self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
|
90 |
+
input_channel = output_channel
|
91 |
+
|
92 |
+
# make it nn.Sequential
|
93 |
+
self.features = nn.Sequential(*self.features)
|
94 |
+
|
95 |
+
self._initialize_weights()
|
96 |
+
|
97 |
+
def forward(self, x):
|
98 |
+
x = self.features(x)
|
99 |
+
# x = x.mean(3).mean(2)
|
100 |
+
# x = self.classifier(x)
|
101 |
+
return x
|
102 |
+
|
103 |
+
def _initialize_weights(self):
|
104 |
+
for m in self.modules():
|
105 |
+
if isinstance(m, nn.Conv2d):
|
106 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
107 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
108 |
+
if m.bias is not None:
|
109 |
+
m.bias.data.zero_()
|
110 |
+
elif isinstance(m, nn.BatchNorm2d):
|
111 |
+
m.weight.data.fill_(1)
|
112 |
+
m.bias.data.zero_()
|
113 |
+
elif isinstance(m, nn.Linear):
|
114 |
+
n = m.weight.size(1)
|
115 |
+
m.weight.data.normal_(0, 0.01)
|
116 |
+
m.bias.data.zero_()
|
117 |
+
|
118 |
+
|
119 |
+
def mobilenet(pretrained=True, **kwargs):
|
120 |
+
"""
|
121 |
+
Constructs a ResNet-50 model.
|
122 |
+
Args:
|
123 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
124 |
+
"""
|
125 |
+
model = MobileNetV2()
|
126 |
+
if pretrained:
|
127 |
+
model_dict = model.state_dict()
|
128 |
+
pretrained_dict = torch.load(cfg.pretrained_basemodel_path,map_location=torch.device('cpu'), weights_only=True)
|
129 |
+
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
130 |
+
model_dict.update(pretrained_dict)
|
131 |
+
model.load_state_dict(model_dict)
|
132 |
+
# state_dict = torch.load(cfg.pretrained_basemodel_path) # add map_location='cpu' if no gpu
|
133 |
+
# model.load_state_dict(state_dict)
|
134 |
+
|
135 |
+
return model
|
136 |
+
|
137 |
+
|
138 |
+
class East(nn.Module):
|
139 |
+
def __init__(self):
|
140 |
+
super(East, self).__init__()
|
141 |
+
self.mobilenet = mobilenet(True)
|
142 |
+
# self.si for stage i
|
143 |
+
self.s1 = nn.Sequential(*list(self.mobilenet.children())[0][0:4])
|
144 |
+
self.s2 = nn.Sequential(*list(self.mobilenet.children())[0][4:7])
|
145 |
+
self.s3 = nn.Sequential(*list(self.mobilenet.children())[0][7:14])
|
146 |
+
self.s4 = nn.Sequential(*list(self.mobilenet.children())[0][14:17])
|
147 |
+
|
148 |
+
self.conv1 = nn.Conv2d(160+96, 128, 1)
|
149 |
+
self.bn1 = nn.BatchNorm2d(128)
|
150 |
+
self.relu1 = nn.ReLU()
|
151 |
+
|
152 |
+
self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
|
153 |
+
self.bn2 = nn.BatchNorm2d(128)
|
154 |
+
self.relu2 = nn.ReLU()
|
155 |
+
|
156 |
+
self.conv3 = nn.Conv2d(128+32, 64, 1)
|
157 |
+
self.bn3 = nn.BatchNorm2d(64)
|
158 |
+
self.relu3 = nn.ReLU()
|
159 |
+
|
160 |
+
self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
|
161 |
+
self.bn4 = nn.BatchNorm2d(64)
|
162 |
+
self.relu4 = nn.ReLU()
|
163 |
+
|
164 |
+
self.conv5 = nn.Conv2d(64+24, 64, 1)
|
165 |
+
self.bn5 = nn.BatchNorm2d(64)
|
166 |
+
self.relu5 = nn.ReLU()
|
167 |
+
|
168 |
+
self.conv6 = nn.Conv2d(64, 32, 3, padding=1)
|
169 |
+
self.bn6 = nn.BatchNorm2d(32)
|
170 |
+
self.relu6 = nn.ReLU()
|
171 |
+
|
172 |
+
self.conv7 = nn.Conv2d(32, 32, 3, padding=1)
|
173 |
+
self.bn7 = nn.BatchNorm2d(32)
|
174 |
+
self.relu7 = nn.ReLU()
|
175 |
+
|
176 |
+
self.conv8 = nn.Conv2d(32, 1, 1)
|
177 |
+
self.sigmoid1 = nn.Sigmoid()
|
178 |
+
self.conv9 = nn.Conv2d(32, 4, 1)
|
179 |
+
self.sigmoid2 = nn.Sigmoid()
|
180 |
+
self.conv10 = nn.Conv2d(32, 1, 1)
|
181 |
+
self.sigmoid3 = nn.Sigmoid()
|
182 |
+
self.unpool1 = nn.Upsample(scale_factor=2, mode='bilinear')
|
183 |
+
self.unpool2 = nn.Upsample(scale_factor=2, mode='bilinear')
|
184 |
+
self.unpool3 = nn.Upsample(scale_factor=2, mode='bilinear')
|
185 |
+
|
186 |
+
# utils.init_weights([self.conv1,self.conv2,self.conv3,self.conv4,
|
187 |
+
# self.conv5,self.conv6,self.conv7,self.conv8,
|
188 |
+
# self.conv9,self.conv10,self.bn1,self.bn2,
|
189 |
+
# self.bn3,self.bn4,self.bn5,self.bn6,self.bn7])
|
190 |
+
|
191 |
+
def forward(self, images):
|
192 |
+
images = utils.mean_image_subtraction(images)
|
193 |
+
|
194 |
+
f0 = self.s1(images)
|
195 |
+
f1 = self.s2(f0)
|
196 |
+
f2 = self.s3(f1)
|
197 |
+
f3 = self.s4(f2)
|
198 |
+
|
199 |
+
# _, f = self.mobilenet(images)
|
200 |
+
h = f3 # bs 2048 w/32 h/32
|
201 |
+
g = (self.unpool1(h)) # bs 2048 w/16 h/16
|
202 |
+
c = self.conv1(torch.cat((g, f2), 1))
|
203 |
+
c = self.bn1(c)
|
204 |
+
c = self.relu1(c)
|
205 |
+
|
206 |
+
h = self.conv2(c) # bs 128 w/16 h/16
|
207 |
+
h = self.bn2(h)
|
208 |
+
h = self.relu2(h)
|
209 |
+
g = self.unpool2(h) # bs 128 w/8 h/8
|
210 |
+
c = self.conv3(torch.cat((g, f1), 1))
|
211 |
+
c = self.bn3(c)
|
212 |
+
c = self.relu3(c)
|
213 |
+
|
214 |
+
h = self.conv4(c) # bs 64 w/8 h/8
|
215 |
+
h = self.bn4(h)
|
216 |
+
h = self.relu4(h)
|
217 |
+
g = self.unpool3(h) # bs 64 w/4 h/4
|
218 |
+
c = self.conv5(torch.cat((g, f0), 1))
|
219 |
+
c = self.bn5(c)
|
220 |
+
c = self.relu5(c)
|
221 |
+
|
222 |
+
h = self.conv6(c) # bs 32 w/4 h/4
|
223 |
+
h = self.bn6(h)
|
224 |
+
h = self.relu6(h)
|
225 |
+
g = self.conv7(h) # bs 32 w/4 h/4
|
226 |
+
g = self.bn7(g)
|
227 |
+
g = self.relu7(g)
|
228 |
+
|
229 |
+
F_score = self.conv8(g) # bs 1 w/4 h/4
|
230 |
+
F_score = self.sigmoid1(F_score)
|
231 |
+
geo_map = self.conv9(g)
|
232 |
+
geo_map = self.sigmoid2(geo_map) * 512
|
233 |
+
angle_map = self.conv10(g)
|
234 |
+
angle_map = self.sigmoid3(angle_map)
|
235 |
+
angle_map = (angle_map - 0.5) * math.pi / 2
|
236 |
+
|
237 |
+
F_geometry = torch.cat((geo_map, angle_map), 1) # bs 5 w/4 h/4
|
238 |
+
|
239 |
+
return F_score, F_geometry
|
240 |
+
|
241 |
+
|
242 |
+
model=East()
|
IndicPhotoOCR/detection/east_preprossing.py
ADDED
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# coding:utf-8
|
3 |
+
import glob
|
4 |
+
import csv
|
5 |
+
import cv2
|
6 |
+
import os
|
7 |
+
import numpy as np
|
8 |
+
from shapely.geometry import Polygon
|
9 |
+
|
10 |
+
|
11 |
+
from IndicPhotoOCR.detection import east_config as cfg
|
12 |
+
from IndicPhotoOCR.detection import east_utils
|
13 |
+
|
14 |
+
|
15 |
+
def get_images(img_root):
|
16 |
+
files = []
|
17 |
+
for ext in ['jpg']:
|
18 |
+
files.extend(glob.glob(
|
19 |
+
os.path.join(img_root, '*.{}'.format(ext))))
|
20 |
+
# print(glob.glob(
|
21 |
+
# os.path.join(FLAGS.training_data_path, '*.{}'.format(ext))))
|
22 |
+
return files
|
23 |
+
|
24 |
+
|
25 |
+
def load_annoataion(p):
|
26 |
+
'''
|
27 |
+
load annotation from the text file
|
28 |
+
:param p:
|
29 |
+
:return:
|
30 |
+
'''
|
31 |
+
text_polys = []
|
32 |
+
text_tags = []
|
33 |
+
if not os.path.exists(p):
|
34 |
+
return np.array(text_polys, dtype=np.float32)
|
35 |
+
with open(p, 'r', encoding='UTF-8') as f:
|
36 |
+
reader = csv.reader(f)
|
37 |
+
for line in reader:
|
38 |
+
label = line[-1]
|
39 |
+
# strip BOM. \ufeff for python3, \xef\xbb\bf for python2
|
40 |
+
line = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in line]
|
41 |
+
|
42 |
+
x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8]))
|
43 |
+
text_polys.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
|
44 |
+
# print(text_polys)
|
45 |
+
if label == '*' or label == '###':
|
46 |
+
text_tags.append(True)
|
47 |
+
else:
|
48 |
+
text_tags.append(False)
|
49 |
+
return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=np.bool)
|
50 |
+
|
51 |
+
|
52 |
+
def polygon_area(poly):
|
53 |
+
'''
|
54 |
+
compute area of a polygon
|
55 |
+
:param poly:
|
56 |
+
:return:
|
57 |
+
'''
|
58 |
+
edge = [
|
59 |
+
(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
60 |
+
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
61 |
+
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
62 |
+
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])
|
63 |
+
]
|
64 |
+
return np.sum(edge) / 2.
|
65 |
+
|
66 |
+
|
67 |
+
def check_and_validate_polys(polys, tags, xxx_todo_changeme):
|
68 |
+
'''
|
69 |
+
check so that the text poly is in the same direction,
|
70 |
+
and also filter some invalid polygons
|
71 |
+
:param polys:
|
72 |
+
:param tags:
|
73 |
+
:return:
|
74 |
+
'''
|
75 |
+
(h, w) = xxx_todo_changeme
|
76 |
+
if polys.shape[0] == 0:
|
77 |
+
return polys
|
78 |
+
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
79 |
+
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
80 |
+
|
81 |
+
validated_polys = []
|
82 |
+
validated_tags = []
|
83 |
+
|
84 |
+
# 判断四边形的点时针方向,以及是否是有效四边形
|
85 |
+
for poly, tag in zip(polys, tags):
|
86 |
+
p_area = polygon_area(poly)
|
87 |
+
if abs(p_area) < 1:
|
88 |
+
# print poly
|
89 |
+
print('invalid poly')
|
90 |
+
continue
|
91 |
+
if p_area > 0:
|
92 |
+
print('poly in wrong direction')
|
93 |
+
poly = poly[(0, 3, 2, 1), :]
|
94 |
+
validated_polys.append(poly)
|
95 |
+
validated_tags.append(tag)
|
96 |
+
return np.array(validated_polys), np.array(validated_tags)
|
97 |
+
|
98 |
+
|
99 |
+
def crop_area(im, polys, tags, crop_background=False, max_tries=100):
|
100 |
+
'''
|
101 |
+
make random crop from the input image
|
102 |
+
:param im:
|
103 |
+
:param polys:
|
104 |
+
:param tags:
|
105 |
+
:param crop_background:
|
106 |
+
:param max_tries:
|
107 |
+
:return:
|
108 |
+
'''
|
109 |
+
h, w, _ = im.shape
|
110 |
+
pad_h = h // 10
|
111 |
+
pad_w = w // 10
|
112 |
+
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
113 |
+
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
114 |
+
for poly in polys:
|
115 |
+
poly = np.round(poly, decimals=0).astype(np.int32)
|
116 |
+
minx = np.min(poly[:, 0])
|
117 |
+
maxx = np.max(poly[:, 0])
|
118 |
+
w_array[minx + pad_w:maxx + pad_w] = 1
|
119 |
+
miny = np.min(poly[:, 1])
|
120 |
+
maxy = np.max(poly[:, 1])
|
121 |
+
h_array[miny + pad_h:maxy + pad_h] = 1
|
122 |
+
# ensure the cropped area not across a text,保证裁剪区域不能与文本交叉
|
123 |
+
h_axis = np.where(h_array == 0)[0]
|
124 |
+
w_axis = np.where(w_array == 0)[0]
|
125 |
+
if len(h_axis) == 0 or len(w_axis) == 0:
|
126 |
+
return im, polys, tags
|
127 |
+
for i in range(max_tries): # 试验50次
|
128 |
+
xx = np.random.choice(w_axis, size=2)
|
129 |
+
xmin = np.min(xx) - pad_w
|
130 |
+
xmax = np.max(xx) - pad_w
|
131 |
+
xmin = np.clip(xmin, 0, w - 1)
|
132 |
+
xmax = np.clip(xmax, 0, w - 1)
|
133 |
+
yy = np.random.choice(h_axis, size=2)
|
134 |
+
ymin = np.min(yy) - pad_h
|
135 |
+
ymax = np.max(yy) - pad_h
|
136 |
+
ymin = np.clip(ymin, 0, h - 1)
|
137 |
+
ymax = np.clip(ymax, 0, h - 1)
|
138 |
+
if xmax - xmin < cfg.min_crop_side_ratio * w or ymax - ymin < cfg.min_crop_side_ratio * h:
|
139 |
+
# area too small
|
140 |
+
continue
|
141 |
+
if polys.shape[0] != 0:
|
142 |
+
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
|
143 |
+
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
|
144 |
+
selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
145 |
+
else:
|
146 |
+
selected_polys = []
|
147 |
+
if len(selected_polys) == 0:
|
148 |
+
# no text in this area
|
149 |
+
if crop_background:
|
150 |
+
return im[ymin:ymax + 1, xmin:xmax + 1, :], polys[selected_polys], tags[selected_polys]
|
151 |
+
else:
|
152 |
+
continue
|
153 |
+
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
|
154 |
+
polys = polys[selected_polys]
|
155 |
+
tags = tags[selected_polys]
|
156 |
+
polys[:, :, 0] -= xmin
|
157 |
+
polys[:, :, 1] -= ymin
|
158 |
+
return im, polys, tags
|
159 |
+
|
160 |
+
return im, polys, tags
|
161 |
+
|
162 |
+
|
163 |
+
def shrink_poly(poly, r):
|
164 |
+
'''
|
165 |
+
fit a poly inside the origin poly, maybe bugs here...
|
166 |
+
used for generate the score map
|
167 |
+
:param poly: the text poly
|
168 |
+
:param r: r in the paper
|
169 |
+
:return: the shrinked poly
|
170 |
+
'''
|
171 |
+
# shrink ratio
|
172 |
+
R = 0.3
|
173 |
+
# find the longer pair
|
174 |
+
if np.linalg.norm(poly[0] - poly[1]) + np.linalg.norm(poly[2] - poly[3]) > \
|
175 |
+
np.linalg.norm(poly[0] - poly[3]) + np.linalg.norm(poly[1] - poly[2]):
|
176 |
+
# first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
|
177 |
+
## p0, p1
|
178 |
+
theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
|
179 |
+
poly[0][0] += R * r[0] * np.cos(theta)
|
180 |
+
poly[0][1] += R * r[0] * np.sin(theta)
|
181 |
+
poly[1][0] -= R * r[1] * np.cos(theta)
|
182 |
+
poly[1][1] -= R * r[1] * np.sin(theta)
|
183 |
+
## p2, p3
|
184 |
+
theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
|
185 |
+
poly[3][0] += R * r[3] * np.cos(theta)
|
186 |
+
poly[3][1] += R * r[3] * np.sin(theta)
|
187 |
+
poly[2][0] -= R * r[2] * np.cos(theta)
|
188 |
+
poly[2][1] -= R * r[2] * np.sin(theta)
|
189 |
+
## p0, p3
|
190 |
+
theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
|
191 |
+
poly[0][0] += R * r[0] * np.sin(theta)
|
192 |
+
poly[0][1] += R * r[0] * np.cos(theta)
|
193 |
+
poly[3][0] -= R * r[3] * np.sin(theta)
|
194 |
+
poly[3][1] -= R * r[3] * np.cos(theta)
|
195 |
+
## p1, p2
|
196 |
+
theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
|
197 |
+
poly[1][0] += R * r[1] * np.sin(theta)
|
198 |
+
poly[1][1] += R * r[1] * np.cos(theta)
|
199 |
+
poly[2][0] -= R * r[2] * np.sin(theta)
|
200 |
+
poly[2][1] -= R * r[2] * np.cos(theta)
|
201 |
+
else:
|
202 |
+
## p0, p3
|
203 |
+
# print poly
|
204 |
+
theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
|
205 |
+
poly[0][0] += R * r[0] * np.sin(theta)
|
206 |
+
poly[0][1] += R * r[0] * np.cos(theta)
|
207 |
+
poly[3][0] -= R * r[3] * np.sin(theta)
|
208 |
+
poly[3][1] -= R * r[3] * np.cos(theta)
|
209 |
+
## p1, p2
|
210 |
+
theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
|
211 |
+
poly[1][0] += R * r[1] * np.sin(theta)
|
212 |
+
poly[1][1] += R * r[1] * np.cos(theta)
|
213 |
+
poly[2][0] -= R * r[2] * np.sin(theta)
|
214 |
+
poly[2][1] -= R * r[2] * np.cos(theta)
|
215 |
+
## p0, p1
|
216 |
+
theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
|
217 |
+
poly[0][0] += R * r[0] * np.cos(theta)
|
218 |
+
poly[0][1] += R * r[0] * np.sin(theta)
|
219 |
+
poly[1][0] -= R * r[1] * np.cos(theta)
|
220 |
+
poly[1][1] -= R * r[1] * np.sin(theta)
|
221 |
+
## p2, p3
|
222 |
+
theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
|
223 |
+
poly[3][0] += R * r[3] * np.cos(theta)
|
224 |
+
poly[3][1] += R * r[3] * np.sin(theta)
|
225 |
+
poly[2][0] -= R * r[2] * np.cos(theta)
|
226 |
+
poly[2][1] -= R * r[2] * np.sin(theta)
|
227 |
+
return poly
|
228 |
+
|
229 |
+
|
230 |
+
# def point_dist_to_line(p1, p2, p3):
|
231 |
+
# # compute the distance from p3 to p1-p2
|
232 |
+
# return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1)
|
233 |
+
|
234 |
+
|
235 |
+
# 点p3到直线p12的距离
|
236 |
+
def point_dist_to_line(p1, p2, p3):
|
237 |
+
# compute the distance from p3 to p1-p2
|
238 |
+
# return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1)
|
239 |
+
a = np.linalg.norm(p1 - p2)
|
240 |
+
b = np.linalg.norm(p2 - p3)
|
241 |
+
c = np.linalg.norm(p3 - p1)
|
242 |
+
s = (a + b + c) / 2.0
|
243 |
+
area = np.abs((s * (s - a) * (s - b) * (s - c))) ** 0.5
|
244 |
+
if a < 1.0:
|
245 |
+
return (b + c) / 2.0
|
246 |
+
return 2 * area / a
|
247 |
+
|
248 |
+
|
249 |
+
def fit_line(p1, p2):
|
250 |
+
# fit a line ax+by+c = 0
|
251 |
+
if p1[0] == p1[1]:
|
252 |
+
return [1., 0., -p1[0]]
|
253 |
+
else:
|
254 |
+
[k, b] = np.polyfit(p1, p2, deg=1)
|
255 |
+
return [k, -1., b]
|
256 |
+
|
257 |
+
|
258 |
+
def line_cross_point(line1, line2):
|
259 |
+
# line1 0= ax+by+c, compute the cross point of line1 and line2
|
260 |
+
if line1[0] != 0 and line1[0] == line2[0]:
|
261 |
+
print('Cross point does not exist')
|
262 |
+
return None
|
263 |
+
if line1[0] == 0 and line2[0] == 0:
|
264 |
+
print('Cross point does not exist')
|
265 |
+
return None
|
266 |
+
if line1[1] == 0:
|
267 |
+
x = -line1[2]
|
268 |
+
y = line2[0] * x + line2[2]
|
269 |
+
elif line2[1] == 0:
|
270 |
+
x = -line2[2]
|
271 |
+
y = line1[0] * x + line1[2]
|
272 |
+
else:
|
273 |
+
k1, _, b1 = line1
|
274 |
+
k2, _, b2 = line2
|
275 |
+
x = -(b1 - b2) / (k1 - k2)
|
276 |
+
y = k1 * x + b1
|
277 |
+
return np.array([x, y], dtype=np.float32)
|
278 |
+
|
279 |
+
|
280 |
+
def line_verticle(line, point):
|
281 |
+
# get the verticle line from line across point
|
282 |
+
if line[1] == 0:
|
283 |
+
verticle = [0, -1, point[1]]
|
284 |
+
else:
|
285 |
+
if line[0] == 0:
|
286 |
+
verticle = [1, 0, -point[0]]
|
287 |
+
else:
|
288 |
+
verticle = [-1. / line[0], -1, point[1] - (-1 / line[0] * point[0])]
|
289 |
+
return verticle
|
290 |
+
|
291 |
+
|
292 |
+
def rectangle_from_parallelogram(poly):
|
293 |
+
'''
|
294 |
+
fit a rectangle from a parallelogram
|
295 |
+
:param poly:
|
296 |
+
:return:
|
297 |
+
'''
|
298 |
+
p0, p1, p2, p3 = poly
|
299 |
+
angle_p0 = np.arccos(np.dot(p1 - p0, p3 - p0) / (np.linalg.norm(p0 - p1) * np.linalg.norm(p3 - p0)))
|
300 |
+
if angle_p0 < 0.5 * np.pi:
|
301 |
+
if np.linalg.norm(p0 - p1) > np.linalg.norm(p0 - p3):
|
302 |
+
# p0 and p2
|
303 |
+
## p0
|
304 |
+
p2p3 = fit_line([p2[0], p3[0]], [p2[1], p3[1]])
|
305 |
+
p2p3_verticle = line_verticle(p2p3, p0)
|
306 |
+
|
307 |
+
new_p3 = line_cross_point(p2p3, p2p3_verticle)
|
308 |
+
## p2
|
309 |
+
p0p1 = fit_line([p0[0], p1[0]], [p0[1], p1[1]])
|
310 |
+
p0p1_verticle = line_verticle(p0p1, p2)
|
311 |
+
|
312 |
+
new_p1 = line_cross_point(p0p1, p0p1_verticle)
|
313 |
+
return np.array([p0, new_p1, p2, new_p3], dtype=np.float32)
|
314 |
+
else:
|
315 |
+
p1p2 = fit_line([p1[0], p2[0]], [p1[1], p2[1]])
|
316 |
+
p1p2_verticle = line_verticle(p1p2, p0)
|
317 |
+
|
318 |
+
new_p1 = line_cross_point(p1p2, p1p2_verticle)
|
319 |
+
p0p3 = fit_line([p0[0], p3[0]], [p0[1], p3[1]])
|
320 |
+
p0p3_verticle = line_verticle(p0p3, p2)
|
321 |
+
|
322 |
+
new_p3 = line_cross_point(p0p3, p0p3_verticle)
|
323 |
+
return np.array([p0, new_p1, p2, new_p3], dtype=np.float32)
|
324 |
+
else:
|
325 |
+
if np.linalg.norm(p0 - p1) > np.linalg.norm(p0 - p3):
|
326 |
+
# p1 and p3
|
327 |
+
## p1
|
328 |
+
p2p3 = fit_line([p2[0], p3[0]], [p2[1], p3[1]])
|
329 |
+
p2p3_verticle = line_verticle(p2p3, p1)
|
330 |
+
|
331 |
+
new_p2 = line_cross_point(p2p3, p2p3_verticle)
|
332 |
+
## p3
|
333 |
+
p0p1 = fit_line([p0[0], p1[0]], [p0[1], p1[1]])
|
334 |
+
p0p1_verticle = line_verticle(p0p1, p3)
|
335 |
+
|
336 |
+
new_p0 = line_cross_point(p0p1, p0p1_verticle)
|
337 |
+
return np.array([new_p0, p1, new_p2, p3], dtype=np.float32)
|
338 |
+
else:
|
339 |
+
p0p3 = fit_line([p0[0], p3[0]], [p0[1], p3[1]])
|
340 |
+
p0p3_verticle = line_verticle(p0p3, p1)
|
341 |
+
|
342 |
+
new_p0 = line_cross_point(p0p3, p0p3_verticle)
|
343 |
+
p1p2 = fit_line([p1[0], p2[0]], [p1[1], p2[1]])
|
344 |
+
p1p2_verticle = line_verticle(p1p2, p3)
|
345 |
+
|
346 |
+
new_p2 = line_cross_point(p1p2, p1p2_verticle)
|
347 |
+
return np.array([new_p0, p1, new_p2, p3], dtype=np.float32)
|
348 |
+
|
349 |
+
|
350 |
+
def sort_rectangle(poly):
|
351 |
+
# sort the four coordinates of the polygon, points in poly should be sorted clockwise
|
352 |
+
# First find the lowest point
|
353 |
+
p_lowest = np.argmax(poly[:, 1])
|
354 |
+
if np.count_nonzero(poly[:, 1] == poly[p_lowest, 1]) == 2:
|
355 |
+
# 底边平行于X轴, 那么p0为左上角 - if the bottom line is parallel to x-axis, then p0 must be the upper-left corner
|
356 |
+
p0_index = np.argmin(np.sum(poly, axis=1))
|
357 |
+
p1_index = (p0_index + 1) % 4
|
358 |
+
p2_index = (p0_index + 2) % 4
|
359 |
+
p3_index = (p0_index + 3) % 4
|
360 |
+
return poly[[p0_index, p1_index, p2_index, p3_index]], 0.
|
361 |
+
else:
|
362 |
+
# 找到最低点右边的点 - find the point that sits right to the lowest point
|
363 |
+
p_lowest_right = (p_lowest - 1) % 4
|
364 |
+
p_lowest_left = (p_lowest + 1) % 4
|
365 |
+
angle = np.arctan(
|
366 |
+
-(poly[p_lowest][1] - poly[p_lowest_right][1]) / (poly[p_lowest][0] - poly[p_lowest_right][0]))
|
367 |
+
# assert angle > 0
|
368 |
+
if angle <= 0:
|
369 |
+
print(angle, poly[p_lowest], poly[p_lowest_right])
|
370 |
+
if angle / np.pi * 180 > 45:
|
371 |
+
# 这个点为p2 - this point is p2
|
372 |
+
p2_index = p_lowest
|
373 |
+
p1_index = (p2_index - 1) % 4
|
374 |
+
p0_index = (p2_index - 2) % 4
|
375 |
+
p3_index = (p2_index + 1) % 4
|
376 |
+
return poly[[p0_index, p1_index, p2_index, p3_index]], -(np.pi / 2 - angle)
|
377 |
+
else:
|
378 |
+
# 这个点为p3 - this point is p3
|
379 |
+
p3_index = p_lowest
|
380 |
+
p0_index = (p3_index + 1) % 4
|
381 |
+
p1_index = (p3_index + 2) % 4
|
382 |
+
p2_index = (p3_index + 3) % 4
|
383 |
+
return poly[[p0_index, p1_index, p2_index, p3_index]], angle
|
384 |
+
|
385 |
+
|
386 |
+
def restore_rectangle_rbox(origin, geometry):
|
387 |
+
d = geometry[:, :4]
|
388 |
+
angle = geometry[:, 4]
|
389 |
+
# for angle > 0
|
390 |
+
origin_0 = origin[angle >= 0]
|
391 |
+
d_0 = d[angle >= 0]
|
392 |
+
angle_0 = angle[angle >= 0]
|
393 |
+
if origin_0.shape[0] > 0:
|
394 |
+
p = np.array([np.zeros(d_0.shape[0]), -d_0[:, 0] - d_0[:, 2],
|
395 |
+
d_0[:, 1] + d_0[:, 3], -d_0[:, 0] - d_0[:, 2],
|
396 |
+
d_0[:, 1] + d_0[:, 3], np.zeros(d_0.shape[0]),
|
397 |
+
np.zeros(d_0.shape[0]), np.zeros(d_0.shape[0]),
|
398 |
+
d_0[:, 3], -d_0[:, 2]])
|
399 |
+
p = p.transpose((1, 0)).reshape((-1, 5, 2)) # N*5*2
|
400 |
+
|
401 |
+
rotate_matrix_x = np.array([np.cos(angle_0), np.sin(angle_0)]).transpose((1, 0))
|
402 |
+
rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) # N*5*2
|
403 |
+
|
404 |
+
rotate_matrix_y = np.array([-np.sin(angle_0), np.cos(angle_0)]).transpose((1, 0))
|
405 |
+
rotate_matrix_y = np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1))
|
406 |
+
|
407 |
+
p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis] # N*5*1
|
408 |
+
p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis] # N*5*1
|
409 |
+
|
410 |
+
p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2) # N*5*2
|
411 |
+
|
412 |
+
p3_in_origin = origin_0 - p_rotate[:, 4, :]
|
413 |
+
new_p0 = p_rotate[:, 0, :] + p3_in_origin # N*2
|
414 |
+
new_p1 = p_rotate[:, 1, :] + p3_in_origin
|
415 |
+
new_p2 = p_rotate[:, 2, :] + p3_in_origin
|
416 |
+
new_p3 = p_rotate[:, 3, :] + p3_in_origin
|
417 |
+
|
418 |
+
new_p_0 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :],
|
419 |
+
new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1) # N*4*2
|
420 |
+
else:
|
421 |
+
new_p_0 = np.zeros((0, 4, 2))
|
422 |
+
# for angle < 0
|
423 |
+
origin_1 = origin[angle < 0]
|
424 |
+
d_1 = d[angle < 0]
|
425 |
+
angle_1 = angle[angle < 0]
|
426 |
+
if origin_1.shape[0] > 0:
|
427 |
+
p = np.array([-d_1[:, 1] - d_1[:, 3], -d_1[:, 0] - d_1[:, 2],
|
428 |
+
np.zeros(d_1.shape[0]), -d_1[:, 0] - d_1[:, 2],
|
429 |
+
np.zeros(d_1.shape[0]), np.zeros(d_1.shape[0]),
|
430 |
+
-d_1[:, 1] - d_1[:, 3], np.zeros(d_1.shape[0]),
|
431 |
+
-d_1[:, 1], -d_1[:, 2]])
|
432 |
+
p = p.transpose((1, 0)).reshape((-1, 5, 2)) # N*5*2
|
433 |
+
|
434 |
+
rotate_matrix_x = np.array([np.cos(-angle_1), -np.sin(-angle_1)]).transpose((1, 0))
|
435 |
+
rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) # N*5*2
|
436 |
+
|
437 |
+
rotate_matrix_y = np.array([np.sin(-angle_1), np.cos(-angle_1)]).transpose((1, 0))
|
438 |
+
rotate_matrix_y = np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1))
|
439 |
+
|
440 |
+
p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis] # N*5*1
|
441 |
+
p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis] # N*5*1
|
442 |
+
|
443 |
+
p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2) # N*5*2
|
444 |
+
|
445 |
+
p3_in_origin = origin_1 - p_rotate[:, 4, :]
|
446 |
+
new_p0 = p_rotate[:, 0, :] + p3_in_origin # N*2
|
447 |
+
new_p1 = p_rotate[:, 1, :] + p3_in_origin
|
448 |
+
new_p2 = p_rotate[:, 2, :] + p3_in_origin
|
449 |
+
new_p3 = p_rotate[:, 3, :] + p3_in_origin
|
450 |
+
|
451 |
+
new_p_1 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :],
|
452 |
+
new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1) # N*4*2
|
453 |
+
else:
|
454 |
+
new_p_1 = np.zeros((0, 4, 2))
|
455 |
+
return np.concatenate([new_p_0, new_p_1])
|
456 |
+
|
457 |
+
|
458 |
+
def restore_rectangle(origin, geometry):
|
459 |
+
return restore_rectangle_rbox(origin, geometry)
|
460 |
+
|
461 |
+
|
462 |
+
def generate_rbox(im_size, polys, tags):
|
463 |
+
h, w = im_size
|
464 |
+
poly_mask = np.zeros((h, w), dtype=np.uint8)
|
465 |
+
score_map = np.zeros((h, w), dtype=np.uint8)
|
466 |
+
geo_map = np.zeros((h, w, 5), dtype=np.float32)
|
467 |
+
# mask used during traning, to ignore some hard areas,用于忽略那些过小的文本
|
468 |
+
training_mask = np.ones((h, w), dtype=np.uint8)
|
469 |
+
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
470 |
+
poly = poly_tag[0]
|
471 |
+
tag = poly_tag[1]
|
472 |
+
|
473 |
+
# 对每个顶点,找到经过他的两条边中较短的那条
|
474 |
+
r = [None, None, None, None]
|
475 |
+
for i in range(4):
|
476 |
+
r[i] = min(np.linalg.norm(poly[i] - poly[(i + 1) % 4]),
|
477 |
+
np.linalg.norm(poly[i] - poly[(i - 1) % 4]))
|
478 |
+
# score map
|
479 |
+
# 放缩边框为之前的0.3倍,并对边框对应score图中的位置进行填充
|
480 |
+
shrinked_poly = shrink_poly(poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
|
481 |
+
cv2.fillPoly(score_map, shrinked_poly, 1)
|
482 |
+
cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
|
483 |
+
# if the poly is too small, then ignore it during training
|
484 |
+
# 如果文本框标签太小或者txt中没具体标记是什么内容,即*或者###,则加掩模,训练时忽略该部分
|
485 |
+
poly_h = min(np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2]))
|
486 |
+
poly_w = min(np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3]))
|
487 |
+
if min(poly_h, poly_w) < cfg.min_text_size:
|
488 |
+
cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
|
489 |
+
if tag:
|
490 |
+
cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
|
491 |
+
|
492 |
+
# 当前新加入的文本框区域像素点
|
493 |
+
xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
|
494 |
+
# if geometry == 'RBOX':
|
495 |
+
# 对任意两个顶点的组合生成一个平行四边形 - generate a parallelogram for any combination of two vertices
|
496 |
+
fitted_parallelograms = []
|
497 |
+
for i in range(4):
|
498 |
+
# 选中p0和p1的连线边,生成两个平行四边形
|
499 |
+
p0 = poly[i]
|
500 |
+
p1 = poly[(i + 1) % 4]
|
501 |
+
p2 = poly[(i + 2) % 4]
|
502 |
+
p3 = poly[(i + 3) % 4]
|
503 |
+
# 拟合ax+by+c=0
|
504 |
+
edge = fit_line([p0[0], p1[0]], [p0[1], p1[1]])
|
505 |
+
backward_edge = fit_line([p0[0], p3[0]], [p0[1], p3[1]])
|
506 |
+
forward_edge = fit_line([p1[0], p2[0]], [p1[1], p2[1]])
|
507 |
+
# 通过另外两个点距离edge的距离,来决定edge对应的平行线应该过p2还是p3
|
508 |
+
if point_dist_to_line(p0, p1, p2) > point_dist_to_line(p0, p1, p3):
|
509 |
+
# 平行线经过p2 - parallel lines through p2
|
510 |
+
if edge[1] == 0:
|
511 |
+
edge_opposite = [1, 0, -p2[0]]
|
512 |
+
else:
|
513 |
+
edge_opposite = [edge[0], -1, p2[1] - edge[0] * p2[0]]
|
514 |
+
else:
|
515 |
+
# 经过p3 - after p3
|
516 |
+
if edge[1] == 0:
|
517 |
+
edge_opposite = [1, 0, -p3[0]]
|
518 |
+
else:
|
519 |
+
edge_opposite = [edge[0], -1, p3[1] - edge[0] * p3[0]]
|
520 |
+
# move forward edge
|
521 |
+
new_p0 = p0
|
522 |
+
new_p1 = p1
|
523 |
+
new_p2 = p2
|
524 |
+
new_p3 = p3
|
525 |
+
new_p2 = line_cross_point(forward_edge, edge_opposite)
|
526 |
+
if point_dist_to_line(p1, new_p2, p0) > point_dist_to_line(p1, new_p2, p3):
|
527 |
+
# across p0
|
528 |
+
if forward_edge[1] == 0:
|
529 |
+
forward_opposite = [1, 0, -p0[0]]
|
530 |
+
else:
|
531 |
+
forward_opposite = [forward_edge[0], -1, p0[1] - forward_edge[0] * p0[0]]
|
532 |
+
else:
|
533 |
+
# across p3
|
534 |
+
if forward_edge[1] == 0:
|
535 |
+
forward_opposite = [1, 0, -p3[0]]
|
536 |
+
else:
|
537 |
+
forward_opposite = [forward_edge[0], -1, p3[1] - forward_edge[0] * p3[0]]
|
538 |
+
new_p0 = line_cross_point(forward_opposite, edge)
|
539 |
+
new_p3 = line_cross_point(forward_opposite, edge_opposite)
|
540 |
+
fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0])
|
541 |
+
# or move backward edge
|
542 |
+
new_p0 = p0
|
543 |
+
new_p1 = p1
|
544 |
+
new_p2 = p2
|
545 |
+
new_p3 = p3
|
546 |
+
new_p3 = line_cross_point(backward_edge, edge_opposite)
|
547 |
+
if point_dist_to_line(p0, p3, p1) > point_dist_to_line(p0, p3, p2):
|
548 |
+
# across p1
|
549 |
+
if backward_edge[1] == 0:
|
550 |
+
backward_opposite = [1, 0, -p1[0]]
|
551 |
+
else:
|
552 |
+
backward_opposite = [backward_edge[0], -1, p1[1] - backward_edge[0] * p1[0]]
|
553 |
+
else:
|
554 |
+
# across p2
|
555 |
+
if backward_edge[1] == 0:
|
556 |
+
backward_opposite = [1, 0, -p2[0]]
|
557 |
+
else:
|
558 |
+
backward_opposite = [backward_edge[0], -1, p2[1] - backward_edge[0] * p2[0]]
|
559 |
+
new_p1 = line_cross_point(backward_opposite, edge)
|
560 |
+
new_p2 = line_cross_point(backward_opposite, edge_opposite)
|
561 |
+
fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0])
|
562 |
+
|
563 |
+
# 选定面积最小的平行四边形
|
564 |
+
areas = [Polygon(t).area for t in fitted_parallelograms]
|
565 |
+
parallelogram = np.array(fitted_parallelograms[np.argmin(areas)][:-1], dtype=np.float32)
|
566 |
+
# sort thie polygon
|
567 |
+
parallelogram_coord_sum = np.sum(parallelogram, axis=1)
|
568 |
+
min_coord_idx = np.argmin(parallelogram_coord_sum)
|
569 |
+
parallelogram = parallelogram[
|
570 |
+
[min_coord_idx, (min_coord_idx + 1) % 4, (min_coord_idx + 2) % 4, (min_coord_idx + 3) % 4]]
|
571 |
+
|
572 |
+
# 得到外包矩形即旋转角
|
573 |
+
rectange = rectangle_from_parallelogram(parallelogram)
|
574 |
+
rectange, rotate_angle = sort_rectangle(rectange)
|
575 |
+
|
576 |
+
p0_rect, p1_rect, p2_rect, p3_rect = rectange
|
577 |
+
# 对当前新加入的文本框区域像素点,根据其到矩形四边的距离修改geo_map
|
578 |
+
for y, x in xy_in_poly:
|
579 |
+
point = np.array([x, y], dtype=np.float32)
|
580 |
+
# top
|
581 |
+
geo_map[y, x, 0] = point_dist_to_line(p0_rect, p1_rect, point)
|
582 |
+
# right
|
583 |
+
geo_map[y, x, 1] = point_dist_to_line(p1_rect, p2_rect, point)
|
584 |
+
# down
|
585 |
+
geo_map[y, x, 2] = point_dist_to_line(p2_rect, p3_rect, point)
|
586 |
+
# left
|
587 |
+
geo_map[y, x, 3] = point_dist_to_line(p3_rect, p0_rect, point)
|
588 |
+
# angle
|
589 |
+
geo_map[y, x, 4] = rotate_angle
|
590 |
+
return score_map, geo_map, training_mask
|
591 |
+
|
592 |
+
|
593 |
+
def generator(index,
|
594 |
+
input_size=512,
|
595 |
+
background_ratio=3. / 8, # 纯背景样本比例
|
596 |
+
random_scale=np.array([0.5, 1, 2.0, 3.0]), # 提取多尺度图片信息
|
597 |
+
image_list=None):
|
598 |
+
try:
|
599 |
+
im_fn = image_list[index]
|
600 |
+
im = cv2.imread(im_fn)
|
601 |
+
if im is None:
|
602 |
+
print("can't find image")
|
603 |
+
return None, None, None, None, None
|
604 |
+
h, w, _ = im.shape
|
605 |
+
# 所以要把gt去掉
|
606 |
+
txt_fn = im_fn.replace(os.path.basename(im_fn).split('.')[1], 'txt')
|
607 |
+
if not os.path.exists(txt_fn):
|
608 |
+
print('text file {} does not exists'.format(txt_fn))
|
609 |
+
return None, None, None, None, None
|
610 |
+
# 加载标注框信息
|
611 |
+
text_polys, text_tags = load_annoataion(txt_fn)
|
612 |
+
|
613 |
+
text_polys, text_tags = check_and_validate_polys(text_polys, text_tags, (h, w))
|
614 |
+
|
615 |
+
# random scale this image,随机选择一种尺度
|
616 |
+
rd_scale = np.random.choice(random_scale)
|
617 |
+
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
618 |
+
text_polys *= rd_scale
|
619 |
+
|
620 |
+
# random crop a area from image,3/8���选中的概率,裁剪纯背景的图片
|
621 |
+
if np.random.rand() < background_ratio:
|
622 |
+
# crop background
|
623 |
+
im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=True)
|
624 |
+
if text_polys.shape[0] > 0:
|
625 |
+
# print("cannot find background")
|
626 |
+
return None, None, None, None, None
|
627 |
+
# pad and resize image
|
628 |
+
new_h, new_w, _ = im.shape
|
629 |
+
max_h_w_i = np.max([new_h, new_w, input_size])
|
630 |
+
im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
|
631 |
+
im_padded[:new_h, :new_w, :] = im.copy()
|
632 |
+
# 将裁剪后图片扩充成512*512的图片
|
633 |
+
im = cv2.resize(im_padded, dsize=(input_size, input_size))
|
634 |
+
score_map = np.zeros((input_size, input_size), dtype=np.uint8)
|
635 |
+
geo_map_channels = 5 if cfg.geometry == 'RBOX' else 8
|
636 |
+
geo_map = np.zeros((input_size, input_size, geo_map_channels), dtype=np.float32)
|
637 |
+
training_mask = np.ones((input_size, input_size), dtype=np.uint8)
|
638 |
+
else:
|
639 |
+
# 5 / 8的选中的概率,裁剪含文本信息的图片
|
640 |
+
im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=False)
|
641 |
+
if text_polys.shape[0] == 0:
|
642 |
+
# print("cannot find txt ground")
|
643 |
+
return None, None, None, None, None
|
644 |
+
h, w, _ = im.shape
|
645 |
+
# pad the image to the training input size or the longer side of image
|
646 |
+
new_h, new_w, _ = im.shape
|
647 |
+
max_h_w_i = np.max([new_h, new_w, input_size])
|
648 |
+
im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
|
649 |
+
im_padded[:new_h, :new_w, :] = im.copy()
|
650 |
+
im = im_padded
|
651 |
+
# resize the image to input size
|
652 |
+
# 填充,resize图像至设定尺寸
|
653 |
+
new_h, new_w, _ = im.shape
|
654 |
+
resize_h = input_size
|
655 |
+
resize_w = input_size
|
656 |
+
im = cv2.resize(im, dsize=(resize_w, resize_h))
|
657 |
+
# 将文本框坐标标签等比例修改
|
658 |
+
resize_ratio_3_x = resize_w / float(new_w)
|
659 |
+
resize_ratio_3_y = resize_h / float(new_h)
|
660 |
+
text_polys[:, :, 0] *= resize_ratio_3_x
|
661 |
+
text_polys[:, :, 1] *= resize_ratio_3_y
|
662 |
+
new_h, new_w, _ = im.shape
|
663 |
+
score_map, geo_map, training_mask = generate_rbox((new_h, new_w), text_polys, text_tags)
|
664 |
+
|
665 |
+
# 将一个样本的样本内容和标签信息append
|
666 |
+
images = im[:,:,::-1].astype(np.float32)
|
667 |
+
# 文件名加入列表
|
668 |
+
image_fns = im_fn
|
669 |
+
# 512*512取提取四分之一行列
|
670 |
+
score_maps = score_map[::4, ::4, np.newaxis].astype(np.float32)
|
671 |
+
geo_maps = geo_map[::4, ::4, :].astype(np.float32)
|
672 |
+
training_masks = training_mask[::4, ::4, np.newaxis].astype(np.float32)
|
673 |
+
# 符合一个样本之后输出
|
674 |
+
return images, image_fns, score_maps, geo_maps, training_masks
|
675 |
+
|
676 |
+
except Exception as e:
|
677 |
+
import traceback
|
678 |
+
traceback.print_exc()
|
679 |
+
|
680 |
+
# print("Exception is exist!")
|
681 |
+
return None, None, None, None, None
|
IndicPhotoOCR/detection/east_utils.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
from torch.nn import init
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import time
|
8 |
+
import requests
|
9 |
+
|
10 |
+
from IndicPhotoOCR.detection import east_config as cfg
|
11 |
+
from IndicPhotoOCR.detection import east_preprossing as preprossing
|
12 |
+
from IndicPhotoOCR.detection import east_locality_aware_nms as locality_aware_nms
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
# Example usage:
|
17 |
+
model_info = {
|
18 |
+
"east": {
|
19 |
+
"paths": [ cfg.checkpoint, cfg.pretrained_basemodel_path],
|
20 |
+
"urls" : ["https://github.com/anikde/STocr/releases/download/e0.1.0/epoch_990_checkpoint.pth.tar", "https://github.com/anikde/STocr/releases/download/e0.1.0/mobilenet_v2.pth.tar"]
|
21 |
+
},
|
22 |
+
}
|
23 |
+
|
24 |
+
class ModelManager:
|
25 |
+
def __init__(self):
|
26 |
+
# self.root_model_dir = "bharatOCR/detection/"
|
27 |
+
pass
|
28 |
+
|
29 |
+
def download_model(self, url, path):
|
30 |
+
response = requests.get(url, stream=True)
|
31 |
+
if response.status_code == 200:
|
32 |
+
with open(path, 'wb') as f:
|
33 |
+
for chunk in response.iter_content(chunk_size=8192):
|
34 |
+
if chunk: # Filter out keep-alive chunks
|
35 |
+
f.write(chunk)
|
36 |
+
print(f"Downloaded: {path}")
|
37 |
+
else:
|
38 |
+
print(f"Failed to download from {url}")
|
39 |
+
|
40 |
+
def ensure_model(self, model_name):
|
41 |
+
model_paths = model_info[model_name]["paths"] # Changed to handle multiple paths
|
42 |
+
urls = model_info[model_name]["urls"] # Changed to handle multiple URLs
|
43 |
+
|
44 |
+
|
45 |
+
for model_path, url in zip(model_paths, urls):
|
46 |
+
# full_model_path = os.path.join(self.root_model_dir, model_path)
|
47 |
+
|
48 |
+
# Ensure the model path directory exists
|
49 |
+
os.makedirs(os.path.dirname(os.path.join(*cfg.pretrained_basemodel_path.split("/"))), exist_ok=True)
|
50 |
+
|
51 |
+
if not os.path.exists(model_path):
|
52 |
+
print(f"Model not found locally. Downloading {model_name} from {url}...")
|
53 |
+
self.download_model(url, model_path)
|
54 |
+
else:
|
55 |
+
print(f"Model already exists at {model_path}. No need to download.")
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
# # Initialize ModelManager and ensure Hindi models are downloaded
|
60 |
+
model_manager = ModelManager()
|
61 |
+
model_manager.ensure_model("east")
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
def init_weights(m_list, init_type=cfg.init_type, gain=0.02):
|
66 |
+
print("EAST <==> Prepare <==> Init Network'{}' <==> Begin".format(cfg.init_type))
|
67 |
+
# this will apply to each layer
|
68 |
+
for m in m_list:
|
69 |
+
classname = m.__class__.__name__
|
70 |
+
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
71 |
+
if init_type == 'normal':
|
72 |
+
init.normal_(m.weight.data, 0.0, gain)
|
73 |
+
elif init_type == 'xavier':
|
74 |
+
init.xavier_normal_(m.weight.data, gain=gain)
|
75 |
+
elif init_type == 'kaiming':
|
76 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') # good for relu
|
77 |
+
elif init_type == 'orthogonal':
|
78 |
+
init.orthogonal_(m.weight.data, gain=gain)
|
79 |
+
else:
|
80 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
81 |
+
|
82 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
83 |
+
init.constant_(m.bias.data, 0.0)
|
84 |
+
elif classname.find('BatchNorm2d') != -1:
|
85 |
+
init.normal_(m.weight.data, 1.0, gain)
|
86 |
+
init.constant_(m.bias.data, 0.0)
|
87 |
+
|
88 |
+
print("EAST <==> Prepare <==> Init Network'{}' <==> Done".format(cfg.init_type))
|
89 |
+
|
90 |
+
|
91 |
+
def Loading_checkpoint(model, optimizer, scheduler, filename='checkpoint.pth.tar'):
|
92 |
+
"""[summary]
|
93 |
+
[description]
|
94 |
+
Arguments:
|
95 |
+
state {[type]} -- [description] a dict describe some params
|
96 |
+
Keyword Arguments:
|
97 |
+
filename {str} -- [description] (default: {'checkpoint.pth.tar'})
|
98 |
+
"""
|
99 |
+
weightpath = os.path.abspath(cfg.checkpoint)
|
100 |
+
print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(weightpath))
|
101 |
+
checkpoint = torch.load(weightpath)
|
102 |
+
start_epoch = checkpoint['epoch'] + 1
|
103 |
+
model.load_state_dict(checkpoint['state_dict'])
|
104 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
105 |
+
scheduler.load_state_dict(checkpoint['scheduler'])
|
106 |
+
print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(weightpath))
|
107 |
+
|
108 |
+
return start_epoch
|
109 |
+
|
110 |
+
|
111 |
+
def save_checkpoint(epoch, model, optimizer, scheduler, filename='checkpoint.pth.tar'):
|
112 |
+
"""[summary]
|
113 |
+
[description]
|
114 |
+
Arguments:
|
115 |
+
state {[type]} -- [description] a dict describe some params
|
116 |
+
Keyword Arguments:
|
117 |
+
filename {str} -- [description] (default: {'checkpoint.pth.tar'})
|
118 |
+
"""
|
119 |
+
print('EAST <==> Save weight - epoch {} <==> Begin'.format(epoch))
|
120 |
+
state = {
|
121 |
+
'epoch': epoch,
|
122 |
+
'state_dict': model.state_dict(),
|
123 |
+
'optimizer': optimizer.state_dict(),
|
124 |
+
'scheduler': scheduler.state_dict()
|
125 |
+
}
|
126 |
+
weight_dir = cfg.save_model_path
|
127 |
+
if not os.path.exists(weight_dir):
|
128 |
+
os.mkdir(weight_dir)
|
129 |
+
filename = 'epoch_' + str(epoch) + '_checkpoint.pth.tar'
|
130 |
+
file_path = os.path.join(weight_dir, filename)
|
131 |
+
torch.save(state, file_path)
|
132 |
+
print('EAST <==> Save weight - epoch {} <==> Done'.format(epoch))
|
133 |
+
|
134 |
+
|
135 |
+
class Regularization(torch.nn.Module):
|
136 |
+
def __init__(self, model, weight_decay, p=2):
|
137 |
+
super(Regularization, self).__init__()
|
138 |
+
if weight_decay < 0:
|
139 |
+
print("param weight_decay can not <0")
|
140 |
+
exit(0)
|
141 |
+
self.model = model
|
142 |
+
self.weight_decay = weight_decay
|
143 |
+
self.p = p
|
144 |
+
self.weight_list = self.get_weight(model)
|
145 |
+
# self.weight_info(self.weight_list)
|
146 |
+
|
147 |
+
def to(self, device):
|
148 |
+
self.device = device
|
149 |
+
super().to(device)
|
150 |
+
return self
|
151 |
+
|
152 |
+
def forward(self, model):
|
153 |
+
self.weight_list = self.get_weight(model) # 获得最新的权重
|
154 |
+
reg_loss = self.regularization_loss(self.weight_list, self.weight_decay, p=self.p)
|
155 |
+
return reg_loss
|
156 |
+
|
157 |
+
def get_weight(self, model):
|
158 |
+
weight_list = []
|
159 |
+
for name, param in model.named_parameters():
|
160 |
+
if 'weight' in name:
|
161 |
+
weight = (name, param)
|
162 |
+
weight_list.append(weight)
|
163 |
+
return weight_list
|
164 |
+
|
165 |
+
def regularization_loss(self, weight_list, weight_decay, p=2):
|
166 |
+
reg_loss = 0
|
167 |
+
for name, w in weight_list:
|
168 |
+
l2_reg = torch.norm(w, p=p)
|
169 |
+
reg_loss = reg_loss + l2_reg
|
170 |
+
|
171 |
+
reg_loss = weight_decay * reg_loss
|
172 |
+
return reg_loss
|
173 |
+
|
174 |
+
def weight_info(self, weight_list):
|
175 |
+
print("---------------regularization weight---------------")
|
176 |
+
for name, w in weight_list:
|
177 |
+
print(name)
|
178 |
+
print("---------------------------------------------------")
|
179 |
+
|
180 |
+
|
181 |
+
def resize_image(im, max_side_len=2400):
|
182 |
+
'''
|
183 |
+
resize image to a size multiple of 32 which is required by the network
|
184 |
+
:param im: the resized image
|
185 |
+
:param max_side_len: limit of max image size to avoid out of memory in gpu
|
186 |
+
:return: the resized image and the resize ratio
|
187 |
+
'''
|
188 |
+
h, w, _ = im.shape
|
189 |
+
|
190 |
+
resize_w = w
|
191 |
+
resize_h = h
|
192 |
+
|
193 |
+
# limit the max side
|
194 |
+
"""
|
195 |
+
if max(resize_h, resize_w) > max_side_len:
|
196 |
+
ratio = float(max_side_len) / resize_h if resize_h > resize_w else float(max_side_len) / resize_w
|
197 |
+
else:
|
198 |
+
ratio = 1.
|
199 |
+
|
200 |
+
resize_h = int(resize_h * ratio)
|
201 |
+
resize_w = int(resize_w * ratio)
|
202 |
+
"""
|
203 |
+
|
204 |
+
resize_h = resize_h if resize_h % 32 == 0 else (resize_h // 32 - 1) * 32
|
205 |
+
resize_w = resize_w if resize_w % 32 == 0 else (resize_w // 32 - 1) * 32
|
206 |
+
#resize_h, resize_w = 512, 512
|
207 |
+
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
208 |
+
|
209 |
+
ratio_h = resize_h / float(h)
|
210 |
+
ratio_w = resize_w / float(w)
|
211 |
+
|
212 |
+
return im, (ratio_h, ratio_w)
|
213 |
+
|
214 |
+
|
215 |
+
def detect(score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_thres=0.2):
|
216 |
+
'''
|
217 |
+
restore text boxes from score map and geo map
|
218 |
+
:param score_map:
|
219 |
+
:param geo_map:
|
220 |
+
:param timer:
|
221 |
+
:param score_map_thresh: threshhold for score map
|
222 |
+
:param box_thresh: threshhold for boxes
|
223 |
+
:param nms_thres: threshold for nms
|
224 |
+
:return:
|
225 |
+
'''
|
226 |
+
|
227 |
+
# score_map 和 geo_map 的维数进行调整
|
228 |
+
if len(score_map.shape) == 4:
|
229 |
+
score_map = score_map[0, :, :, 0]
|
230 |
+
geo_map = geo_map[0, :, :, :]
|
231 |
+
# filter the score map
|
232 |
+
xy_text = np.argwhere(score_map > score_map_thresh)
|
233 |
+
# sort the text boxes via the y axis
|
234 |
+
xy_text = xy_text[np.argsort(xy_text[:, 0])]
|
235 |
+
# restore
|
236 |
+
start = time.time()
|
237 |
+
text_box_restored = preprossing.restore_rectangle(xy_text[:, ::-1] * 4,
|
238 |
+
geo_map[xy_text[:, 0], xy_text[:, 1], :]) # N*4*2
|
239 |
+
# print('{} text boxes before nms'.format(text_box_restored.shape[0]))
|
240 |
+
boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
|
241 |
+
boxes[:, :8] = text_box_restored.reshape((-1, 8))
|
242 |
+
boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
|
243 |
+
timer['restore'] = time.time() - start
|
244 |
+
# nms part
|
245 |
+
start = time.time()
|
246 |
+
boxes = locality_aware_nms.nms_locality(boxes.astype(np.float64), nms_thres)
|
247 |
+
timer['nms'] = time.time() - start
|
248 |
+
# print(timer['nms'])
|
249 |
+
if boxes.shape[0] == 0:
|
250 |
+
return None, timer
|
251 |
+
|
252 |
+
# here we filter some low score boxes by the average score map, this is different from the orginal paper
|
253 |
+
for i, box in enumerate(boxes):
|
254 |
+
mask = np.zeros_like(score_map, dtype=np.uint8)
|
255 |
+
cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
|
256 |
+
boxes[i, 8] = cv2.mean(score_map, mask)[0]
|
257 |
+
boxes = boxes[boxes[:, 8] > box_thresh]
|
258 |
+
return boxes, timer
|
259 |
+
|
260 |
+
|
261 |
+
def sort_poly(p):
|
262 |
+
min_axis = np.argmin(np.sum(p, axis=1))
|
263 |
+
p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
|
264 |
+
if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
|
265 |
+
return p
|
266 |
+
else:
|
267 |
+
return p[[0, 3, 2, 1]]
|
268 |
+
|
269 |
+
|
270 |
+
def mean_image_subtraction(images, means=cfg.means):
|
271 |
+
'''
|
272 |
+
image normalization
|
273 |
+
:param images: bs * w * h * channel
|
274 |
+
:param means:
|
275 |
+
:return:
|
276 |
+
'''
|
277 |
+
num_channels = images.data.shape[1]
|
278 |
+
if len(means) != num_channels:
|
279 |
+
raise ValueError('len(means) must match the number of channels')
|
280 |
+
for i in range(num_channels):
|
281 |
+
images.data[:, i, :, :] -= means[i]
|
282 |
+
|
283 |
+
return images
|
IndicPhotoOCR/ocr.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
from IndicPhotoOCR.detection.east_detector import EASTdetector
|
10 |
+
from IndicPhotoOCR.script_identification.CLIP_identifier import CLIPidentifier
|
11 |
+
from IndicPhotoOCR.recognition.parseq_recogniser import PARseqrecogniser
|
12 |
+
import IndicPhotoOCR.detection.east_config as cfg
|
13 |
+
|
14 |
+
|
15 |
+
class OCR:
|
16 |
+
def __init__(self, device='cuda:0', verbose=False):
|
17 |
+
# self.detect_model_checkpoint = detect_model_checkpoint
|
18 |
+
self.device = device
|
19 |
+
self.verbose = verbose
|
20 |
+
# self.image_path = image_path
|
21 |
+
self.detector = EASTdetector()
|
22 |
+
self.recogniser = PARseqrecogniser()
|
23 |
+
self.identifier = CLIPidentifier()
|
24 |
+
|
25 |
+
def detect(self, image_path, detect_model_checkpoint=cfg.checkpoint):
|
26 |
+
"""Run the detection model to get bounding boxes of text areas."""
|
27 |
+
|
28 |
+
if self.verbose:
|
29 |
+
print("Running text detection...")
|
30 |
+
detections = self.detector.detect(image_path, detect_model_checkpoint, self.device)
|
31 |
+
# print(detections)
|
32 |
+
return detections['detections']
|
33 |
+
|
34 |
+
def visualize_detection(self, image_path, detections, save_path=None, show=False):
|
35 |
+
# Default save path if none is provided
|
36 |
+
default_save_path = "test.png"
|
37 |
+
path_to_save = save_path if save_path is not None else default_save_path
|
38 |
+
|
39 |
+
# Get the directory part of the path
|
40 |
+
directory = os.path.dirname(path_to_save)
|
41 |
+
|
42 |
+
# Check if the directory exists, and create it if it doesn’t
|
43 |
+
if directory and not os.path.exists(directory):
|
44 |
+
os.makedirs(directory)
|
45 |
+
print(f"Created directory: {directory}")
|
46 |
+
|
47 |
+
# Read the image and draw bounding boxes
|
48 |
+
image = cv2.imread(image_path)
|
49 |
+
for box in detections:
|
50 |
+
# Convert list of points to a numpy array with int type
|
51 |
+
points = np.array(box, np.int32)
|
52 |
+
points = points.reshape((-1, 1, 2)) # Reshape for cv2.polylines
|
53 |
+
# Draw the polygon
|
54 |
+
cv2.polylines(image, [points], isClosed=True, color=(0, 255, 0), thickness=3)
|
55 |
+
|
56 |
+
# Show the image if 'show' is True
|
57 |
+
if show:
|
58 |
+
plt.figure(figsize=(10, 10))
|
59 |
+
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
60 |
+
plt.axis("off")
|
61 |
+
plt.show()
|
62 |
+
|
63 |
+
# Save the annotated image
|
64 |
+
cv2.imwrite(path_to_save, image)
|
65 |
+
print(f"Image saved at: {path_to_save}")
|
66 |
+
|
67 |
+
def crop_and_identify_script(self, image, bbox):
|
68 |
+
"""
|
69 |
+
Crop a text area from the image and identify its script language.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
image (PIL.Image): The full image.
|
73 |
+
bbox (list): List of four corner points, each a [x, y] pair.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
str: Identified script language.
|
77 |
+
"""
|
78 |
+
# Extract x and y coordinates from the four corner points
|
79 |
+
x_coords = [point[0] for point in bbox]
|
80 |
+
y_coords = [point[1] for point in bbox]
|
81 |
+
|
82 |
+
# Get the bounding box coordinates (min and max)
|
83 |
+
x_min, y_min = min(x_coords), min(y_coords)
|
84 |
+
x_max, y_max = max(x_coords), max(y_coords)
|
85 |
+
|
86 |
+
# Crop the image based on the bounding box
|
87 |
+
cropped_image = image.crop((x_min, y_min, x_max, y_max))
|
88 |
+
root_image_dir = "IndicPhotoOCR/script_identification"
|
89 |
+
os.makedirs(f"{root_image_dir}/images", exist_ok=True)
|
90 |
+
# Temporarily save the cropped image to pass to the script model
|
91 |
+
cropped_path = f'{root_image_dir}/images/temp_crop_{x_min}_{y_min}.jpg'
|
92 |
+
cropped_image.save(cropped_path)
|
93 |
+
|
94 |
+
# Predict script language, here we assume "hindi" as the model name
|
95 |
+
if self.verbose:
|
96 |
+
print("Identifying script for the cropped area...")
|
97 |
+
script_lang = self.identifier.identify(cropped_path, "hindi") # Use "hindi" as the model name
|
98 |
+
# print(script_lang)
|
99 |
+
|
100 |
+
# Clean up temporary file
|
101 |
+
# os.remove(cropped_path)
|
102 |
+
|
103 |
+
return script_lang, cropped_path
|
104 |
+
|
105 |
+
def recognise(self, cropped_image_path, script_lang):
|
106 |
+
"""Recognize text in a cropped image area using the identified script."""
|
107 |
+
if self.verbose:
|
108 |
+
print("Recognizing text in detected area...")
|
109 |
+
recognized_text = self.recogniser.recognise(script_lang, cropped_image_path, script_lang, self.verbose)
|
110 |
+
# print(recognized_text)
|
111 |
+
return recognized_text
|
112 |
+
|
113 |
+
def ocr(self, image_path):
|
114 |
+
"""Process the image by detecting text areas, identifying script, and recognizing text."""
|
115 |
+
recognized_words = []
|
116 |
+
image = Image.open(image_path)
|
117 |
+
|
118 |
+
# Run detection
|
119 |
+
detections = self.detect(image_path)
|
120 |
+
|
121 |
+
# Process each detected text area
|
122 |
+
for bbox in detections:
|
123 |
+
# Crop and identify script language
|
124 |
+
script_lang, cropped_path = self.crop_and_identify_script(image, bbox)
|
125 |
+
|
126 |
+
# Check if the script language is valid
|
127 |
+
if script_lang:
|
128 |
+
|
129 |
+
# Recognize text
|
130 |
+
recognized_word = self.recognise(cropped_path, script_lang)
|
131 |
+
recognized_words.append(recognized_word)
|
132 |
+
|
133 |
+
if self.verbose:
|
134 |
+
print(f"Recognized word: {recognized_word}")
|
135 |
+
|
136 |
+
return recognized_words
|
137 |
+
|
138 |
+
if __name__ == '__main__':
|
139 |
+
# detect_model_checkpoint = 'bharatSTR/East/tmp/epoch_990_checkpoint.pth.tar'
|
140 |
+
sample_image_path = 'test_images/image_141.jpg'
|
141 |
+
cropped_image_path = 'test_images/cropped_image/image_141_0.jpg'
|
142 |
+
|
143 |
+
ocr = OCR(device="cpu", verbose=False)
|
144 |
+
|
145 |
+
# detections = ocr.detect(sample_image_path)
|
146 |
+
# print(detections)
|
147 |
+
|
148 |
+
# ocr.visualize_detection(sample_image_path, detections)
|
149 |
+
|
150 |
+
# recognition = ocr.recognise(cropped_image_path, "hindi")
|
151 |
+
# print(recognition)
|
152 |
+
|
153 |
+
recognised_words = ocr.ocr(sample_image_path)
|
154 |
+
print(recognised_words)
|
IndicPhotoOCR/recognition/__init__.py
ADDED
File without changes
|
IndicPhotoOCR/recognition/parseq_recogniser.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
# import fire
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
# import pandas as pd
|
7 |
+
import sys
|
8 |
+
import torch
|
9 |
+
import requests
|
10 |
+
|
11 |
+
from dataclasses import dataclass
|
12 |
+
from PIL import Image
|
13 |
+
from nltk import edit_distance
|
14 |
+
from torchvision import transforms as T
|
15 |
+
from typing import Optional, Callable, Sequence, Tuple
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
|
19 |
+
from IndicPhotoOCR.utils.strhub.data.module import SceneTextDataModule
|
20 |
+
from IndicPhotoOCR.utils.strhub.models.utils import load_from_checkpoint
|
21 |
+
|
22 |
+
|
23 |
+
model_info = {
|
24 |
+
"assamese": {
|
25 |
+
"path": "models/assamese.ckpt",
|
26 |
+
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/assamese.ckpt",
|
27 |
+
},
|
28 |
+
"bengali": {
|
29 |
+
"path": "models/bengali.ckpt",
|
30 |
+
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/bengali.ckpt",
|
31 |
+
},
|
32 |
+
"hindi": {
|
33 |
+
"path": "models/hindi.ckpt",
|
34 |
+
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/hindi.ckpt",
|
35 |
+
},
|
36 |
+
"gujarati": {
|
37 |
+
"path": "models/gujarati.ckpt",
|
38 |
+
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/gujarati.ckpt",
|
39 |
+
},
|
40 |
+
"marathi": {
|
41 |
+
"path": "models/marathi.ckpt",
|
42 |
+
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/marathi.ckpt",
|
43 |
+
},
|
44 |
+
"odia": {
|
45 |
+
"path": "models/odia.ckpt",
|
46 |
+
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/odia.ckpt",
|
47 |
+
},
|
48 |
+
"punjabi": {
|
49 |
+
"path": "models/punjabi.ckpt",
|
50 |
+
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/punjabi.ckpt",
|
51 |
+
},
|
52 |
+
"tamil": {
|
53 |
+
"path": "models/tamil.ckpt",
|
54 |
+
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/tamil.ckpt",
|
55 |
+
},
|
56 |
+
"telugu": {
|
57 |
+
"path": "models/telugu.ckpt",
|
58 |
+
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/telugu.ckpt",
|
59 |
+
}
|
60 |
+
}
|
61 |
+
|
62 |
+
class PARseqrecogniser:
|
63 |
+
def __init__(self):
|
64 |
+
pass
|
65 |
+
|
66 |
+
def get_transform(self, img_size: Tuple[int], augment: bool = False, rotation: int = 0):
|
67 |
+
transforms = []
|
68 |
+
if augment:
|
69 |
+
from .augment import rand_augment_transform
|
70 |
+
transforms.append(rand_augment_transform())
|
71 |
+
if rotation:
|
72 |
+
transforms.append(lambda img: img.rotate(rotation, expand=True))
|
73 |
+
transforms.extend([
|
74 |
+
T.Resize(img_size, T.InterpolationMode.BICUBIC),
|
75 |
+
T.ToTensor(),
|
76 |
+
T.Normalize(0.5, 0.5)
|
77 |
+
])
|
78 |
+
return T.Compose(transforms)
|
79 |
+
|
80 |
+
|
81 |
+
def load_model(self, device, checkpoint):
|
82 |
+
model = load_from_checkpoint(checkpoint).eval().to(device)
|
83 |
+
return model
|
84 |
+
|
85 |
+
def get_model_output(self, device, model, image_path):
|
86 |
+
hp = model.hparams
|
87 |
+
transform = self.get_transform(hp.img_size, rotation=0)
|
88 |
+
|
89 |
+
image_name = image_path.split("/")[-1]
|
90 |
+
img = Image.open(image_path).convert('RGB')
|
91 |
+
img = transform(img)
|
92 |
+
logits = model(img.unsqueeze(0).to(device))
|
93 |
+
probs = logits.softmax(-1)
|
94 |
+
preds, probs = model.tokenizer.decode(probs)
|
95 |
+
text = model.charset_adapter(preds[0])
|
96 |
+
scores = probs[0].detach().cpu().numpy()
|
97 |
+
|
98 |
+
return text
|
99 |
+
|
100 |
+
# Ensure model file exists; download directly if not
|
101 |
+
def ensure_model(self, model_name):
|
102 |
+
model_path = model_info[model_name]["path"]
|
103 |
+
url = model_info[model_name]["url"]
|
104 |
+
root_model_dir = "IndicPhotoOCR/recognition/"
|
105 |
+
model_path = os.path.join(root_model_dir, model_path)
|
106 |
+
|
107 |
+
if not os.path.exists(model_path):
|
108 |
+
print(f"Model not found locally. Downloading {model_name} from {url}...")
|
109 |
+
|
110 |
+
# Start the download with a progress bar
|
111 |
+
response = requests.get(url, stream=True)
|
112 |
+
total_size = int(response.headers.get('content-length', 0))
|
113 |
+
os.makedirs(f"{root_model_dir}/models", exist_ok=True)
|
114 |
+
|
115 |
+
with open(model_path, "wb") as f, tqdm(
|
116 |
+
desc=model_name,
|
117 |
+
total=total_size,
|
118 |
+
unit='B',
|
119 |
+
unit_scale=True,
|
120 |
+
unit_divisor=1024,
|
121 |
+
) as bar:
|
122 |
+
for data in response.iter_content(chunk_size=1024):
|
123 |
+
f.write(data)
|
124 |
+
bar.update(len(data))
|
125 |
+
|
126 |
+
print(f"Downloaded model for {model_name}.")
|
127 |
+
|
128 |
+
return model_path
|
129 |
+
|
130 |
+
def bstr(checkpoint, language, image_dir, save_dir):
|
131 |
+
"""
|
132 |
+
Runs the OCR model to process images and save the output as a JSON file.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
checkpoint (str): Path to the model checkpoint file.
|
136 |
+
language (str): Language code (e.g., 'hindi', 'english').
|
137 |
+
image_dir (str): Directory containing the images to process.
|
138 |
+
save_dir (str): Directory where the output JSON file will be saved.
|
139 |
+
|
140 |
+
Example usage:
|
141 |
+
python your_script.py --checkpoint /path/to/checkpoint.ckpt --language hindi --image_dir /path/to/images --save_dir /path/to/save
|
142 |
+
"""
|
143 |
+
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
144 |
+
|
145 |
+
if language != "english":
|
146 |
+
model = load_model(device, checkpoint)
|
147 |
+
else:
|
148 |
+
model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)
|
149 |
+
|
150 |
+
parseq_dict = {}
|
151 |
+
for image_path in tqdm(os.listdir(image_dir)):
|
152 |
+
assert os.path.exists(os.path.join(image_dir, image_path)) == True, f"{image_path}"
|
153 |
+
text = get_model_output(device, model, os.path.join(image_dir, image_path), language=f"{language}")
|
154 |
+
|
155 |
+
filename = image_path.split('/')[-1]
|
156 |
+
parseq_dict[filename] = text
|
157 |
+
|
158 |
+
os.makedirs(save_dir, exist_ok=True)
|
159 |
+
with open(f"{save_dir}/{language}_test.json", 'w') as json_file:
|
160 |
+
json.dump(parseq_dict, json_file, indent=4, ensure_ascii=False)
|
161 |
+
|
162 |
+
|
163 |
+
def bstr_onImage(checkpoint, language, image_path):
|
164 |
+
"""
|
165 |
+
Runs the OCR model to process images and save the output as a JSON file.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
checkpoint (str): Path to the model checkpoint file.
|
169 |
+
language (str): Language code (e.g., 'hindi', 'english').
|
170 |
+
image_dir (str): Directory containing the images to process.
|
171 |
+
save_dir (str): Directory where the output JSON file will be saved.
|
172 |
+
|
173 |
+
Example usage:
|
174 |
+
python your_script.py --checkpoint /path/to/checkpoint.ckpt --language hindi --image_dir /path/to/images --save_dir /path/to/save
|
175 |
+
"""
|
176 |
+
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
177 |
+
|
178 |
+
if language != "english":
|
179 |
+
model = load_model(device, checkpoint)
|
180 |
+
else:
|
181 |
+
model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)
|
182 |
+
|
183 |
+
# parseq_dict = {}
|
184 |
+
# for image_path in tqdm(os.listdir(image_dir)):
|
185 |
+
# assert os.path.exists(os.path.join(image_dir, image_path)) == True, f"{image_path}"
|
186 |
+
text = get_model_output(device, model, image_path, language=f"{language}")
|
187 |
+
|
188 |
+
return text
|
189 |
+
|
190 |
+
|
191 |
+
def recognise(self, checkpoint: str, image_path: str, language: str, verbose: bool) -> str:
|
192 |
+
"""
|
193 |
+
Loads the desired model and returns the recognized word from the specified image.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
checkpoint (str): Path to the model checkpoint file.
|
197 |
+
language (str): Language code (e.g., 'hindi', 'english').
|
198 |
+
image_path (str): Path to the image for which text recognition is needed.
|
199 |
+
|
200 |
+
Returns:
|
201 |
+
str: The recognized text from the image.
|
202 |
+
"""
|
203 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
204 |
+
|
205 |
+
if language != "english":
|
206 |
+
model_path = self.ensure_model(checkpoint)
|
207 |
+
model = self.load_model(device, model_path)
|
208 |
+
else:
|
209 |
+
model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True, verbose=verbose).eval().to(device)
|
210 |
+
|
211 |
+
recognized_text = self.get_model_output(device, model, image_path)
|
212 |
+
|
213 |
+
return recognized_text
|
214 |
+
# if __name__ == '__main__':
|
215 |
+
# fire.Fire(main)
|
IndicPhotoOCR/script_identification/CLIP_identifier.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import clip
|
4 |
+
from PIL import Image
|
5 |
+
from io import BytesIO
|
6 |
+
import os
|
7 |
+
import requests
|
8 |
+
|
9 |
+
# Model information dictionary containing model paths and language subcategories
|
10 |
+
model_info = {
|
11 |
+
"hindi": {
|
12 |
+
"path": "models/clip_finetuned_hindienglish_real.pth",
|
13 |
+
"url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglish_real.pth",
|
14 |
+
"subcategories": ["hindi", "english"]
|
15 |
+
},
|
16 |
+
"hinengasm": {
|
17 |
+
"path": "models/clip_finetuned_hindienglishassamese_real.pth",
|
18 |
+
"url": "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishassamese_real.pth",
|
19 |
+
"subcategories": ["hindi", "english", "assamese"]
|
20 |
+
},
|
21 |
+
"hinengben": {
|
22 |
+
"path": "models/clip_finetuned_hindienglishbengali_real.pth",
|
23 |
+
"url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishbengali_real.pth",
|
24 |
+
"subcategories": ["hindi", "english", "bengali"]
|
25 |
+
},
|
26 |
+
"hinengguj": {
|
27 |
+
"path": "models/clip_finetuned_hindienglishgujarati_real.pth",
|
28 |
+
"url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishgujarati_real.pth",
|
29 |
+
"subcategories": ["hindi", "english", "gujarati"]
|
30 |
+
},
|
31 |
+
"hinengkan": {
|
32 |
+
"path": "models/clip_finetuned_hindienglishkannada_real.pth",
|
33 |
+
"url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishkannada_real.pth",
|
34 |
+
"subcategories": ["hindi", "english", "kannada"]
|
35 |
+
},
|
36 |
+
"hinengmal": {
|
37 |
+
"path": "models/clip_finetuned_hindienglishmalayalam_real.pth",
|
38 |
+
"url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmalayalam_real.pth",
|
39 |
+
"subcategories": ["hindi", "english", "malayalam"]
|
40 |
+
},
|
41 |
+
"hinengmar": {
|
42 |
+
"path": "models/clip_finetuned_hindienglishmarathi_real.pth",
|
43 |
+
"url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmarathi_real.pth",
|
44 |
+
"subcategories": ["hindi", "english", "marathi"]
|
45 |
+
},
|
46 |
+
"hinengmei": {
|
47 |
+
"path": "models/clip_finetuned_hindienglishmeitei_real.pth",
|
48 |
+
"url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmeitei_real.pth",
|
49 |
+
"subcategories": ["hindi", "english", "meitei"]
|
50 |
+
},
|
51 |
+
"hinengodi": {
|
52 |
+
"path": "models/clip_finetuned_hindienglishodia_real.pth",
|
53 |
+
"url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishodia_real.pth",
|
54 |
+
"subcategories": ["hindi", "english", "odia"]
|
55 |
+
},
|
56 |
+
"hinengpun": {
|
57 |
+
"path": "models/clip_finetuned_hindienglishpunjabi_real.pth",
|
58 |
+
"url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishpunjabi_real.pth",
|
59 |
+
"subcategories": ["hindi", "english", "punjabi"]
|
60 |
+
},
|
61 |
+
"hinengtam": {
|
62 |
+
"path": "models/clip_finetuned_hindienglishtamil_real.pth",
|
63 |
+
"url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishtamil_real.pth",
|
64 |
+
"subcategories": ["hindi", "english", "tamil"]
|
65 |
+
},
|
66 |
+
"hinengtel": {
|
67 |
+
"path": "models/clip_finetuned_hindienglishtelugu_real.pth",
|
68 |
+
"url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishtelugu_real.pth",
|
69 |
+
"subcategories": ["hindi", "english", "telugu"]
|
70 |
+
},
|
71 |
+
"hinengurd": {
|
72 |
+
"path": "models/clip_finetuned_hindienglishurdu_real.pth",
|
73 |
+
"url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishurdu_real.pth",
|
74 |
+
"subcategories": ["hindi", "english", "urdu"]
|
75 |
+
},
|
76 |
+
|
77 |
+
|
78 |
+
}
|
79 |
+
|
80 |
+
|
81 |
+
# Set device to CUDA if available, otherwise use CPU
|
82 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
83 |
+
clip_model, preprocess = clip.load("ViT-B/32", device=device)
|
84 |
+
|
85 |
+
class CLIPFineTuner(torch.nn.Module):
|
86 |
+
"""
|
87 |
+
Fine-tuning class for the CLIP model to adapt to specific tasks.
|
88 |
+
|
89 |
+
Attributes:
|
90 |
+
model (torch.nn.Module): The CLIP model to be fine-tuned.
|
91 |
+
classifier (torch.nn.Linear): A linear classifier to map features to the desired number of classes.
|
92 |
+
"""
|
93 |
+
def __init__(self, model, num_classes):
|
94 |
+
"""
|
95 |
+
Initializes the fine-tuner with the CLIP model and classifier.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
model (torch.nn.Module): The base CLIP model.
|
99 |
+
num_classes (int): The number of target classes for classification.
|
100 |
+
"""
|
101 |
+
super(CLIPFineTuner, self).__init__()
|
102 |
+
self.model = model
|
103 |
+
self.classifier = torch.nn.Linear(model.visual.output_dim, num_classes)
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
"""
|
107 |
+
Forward pass for image classification.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
x (torch.Tensor): Preprocessed input tensor for an image.
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
torch.Tensor: Logits for each class.
|
114 |
+
"""
|
115 |
+
with torch.no_grad():
|
116 |
+
features = self.model.encode_image(x).float() # Extract image features from CLIP model
|
117 |
+
return self.classifier(features) # Return class logits
|
118 |
+
|
119 |
+
class CLIPidentifier:
|
120 |
+
def __init__(self):
|
121 |
+
pass
|
122 |
+
|
123 |
+
# Ensure model file exists; download directly if not
|
124 |
+
def ensure_model(self, model_name):
|
125 |
+
model_path = model_info[model_name]["path"]
|
126 |
+
url = model_info[model_name]["url"]
|
127 |
+
root_model_dir = "IndicPhotoOCR/script_identification/"
|
128 |
+
model_path = os.path.join(root_model_dir, model_path)
|
129 |
+
|
130 |
+
if not os.path.exists(model_path):
|
131 |
+
print(f"Model not found locally. Downloading {model_name} from {url}...")
|
132 |
+
response = requests.get(url, stream=True)
|
133 |
+
os.makedirs(f"{root_model_dir}/models", exist_ok=True)
|
134 |
+
with open(f"{model_path}", "wb") as f:
|
135 |
+
f.write(response.content)
|
136 |
+
print(f"Downloaded model for {model_name}.")
|
137 |
+
|
138 |
+
return model_path
|
139 |
+
|
140 |
+
# Prediction function to verify and load the model
|
141 |
+
def identify(self, image_path, model_name):
|
142 |
+
"""
|
143 |
+
Predicts the class of an input image using a fine-tuned CLIP model.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
image_path (str): Path to the input image file.
|
147 |
+
model_name (str): Name of the model (e.g., hineng, hinengpun, hinengguj) as specified in `model_info`.
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
dict: Contains either `predicted_class` if successful or `error` if an exception occurs.
|
151 |
+
|
152 |
+
Example usage:
|
153 |
+
result = predict("sample_image.jpg", "hinengguj")
|
154 |
+
print(result) # Output might be {'predicted_class': 'hindi'}
|
155 |
+
"""
|
156 |
+
try:
|
157 |
+
# Validate model name and retrieve associated subcategories
|
158 |
+
if model_name not in model_info:
|
159 |
+
return {"error": "Invalid model name"}
|
160 |
+
|
161 |
+
# Ensure the model file is downloaded and accessible
|
162 |
+
model_path = self.ensure_model(model_name)
|
163 |
+
|
164 |
+
|
165 |
+
subcategories = model_info[model_name]["subcategories"]
|
166 |
+
num_classes = len(subcategories)
|
167 |
+
|
168 |
+
# Load the fine-tuned model with the specified number of classes
|
169 |
+
model_ft = CLIPFineTuner(clip_model, num_classes)
|
170 |
+
model_ft.load_state_dict(torch.load(model_path, map_location=device))
|
171 |
+
model_ft = model_ft.to(device)
|
172 |
+
model_ft.eval()
|
173 |
+
|
174 |
+
# Load and preprocess the image
|
175 |
+
image = Image.open(image_path).convert("RGB")
|
176 |
+
input_tensor = preprocess(image).unsqueeze(0).to(device)
|
177 |
+
|
178 |
+
# Run the model and get the prediction
|
179 |
+
outputs = model_ft(input_tensor)
|
180 |
+
_, predicted_idx = torch.max(outputs, 1)
|
181 |
+
predicted_class = subcategories[predicted_idx.item()]
|
182 |
+
|
183 |
+
return predicted_class
|
184 |
+
|
185 |
+
except Exception as e:
|
186 |
+
return {"error": str(e)}
|
187 |
+
|
188 |
+
|
189 |
+
# if __name__ == "__main__":
|
190 |
+
# import argparse
|
191 |
+
|
192 |
+
# # Argument parser for command line usage
|
193 |
+
# parser = argparse.ArgumentParser(description="Image classification using CLIP fine-tuned model")
|
194 |
+
# parser.add_argument("image_path", type=str, help="Path to the input image")
|
195 |
+
# parser.add_argument("model_name", type=str, choices=model_info.keys(), help="Name of the model (e.g., hineng, hinengpun, hinengguj)")
|
196 |
+
|
197 |
+
# args = parser.parse_args()
|
198 |
+
|
199 |
+
# # Execute prediction with command line inputs
|
200 |
+
# result = predict(args.image_path, args.model_name)
|
201 |
+
# print(result)
|
IndicPhotoOCR/script_identification/__init__.py
ADDED
File without changes
|
IndicPhotoOCR/theme.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
from typing import Iterable
|
3 |
+
import gradio as gr
|
4 |
+
from gradio.themes.base import Base
|
5 |
+
from gradio.themes.utils import colors, fonts, sizes
|
6 |
+
import time
|
7 |
+
|
8 |
+
|
9 |
+
class Seafoam(Base):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
*,
|
13 |
+
primary_hue: colors.Color | str = colors.emerald,
|
14 |
+
secondary_hue: colors.Color | str = colors.blue,
|
15 |
+
neutral_hue: colors.Color | str = colors.gray,
|
16 |
+
spacing_size: sizes.Size | str = sizes.spacing_md,
|
17 |
+
radius_size: sizes.Size | str = sizes.radius_md,
|
18 |
+
text_size: sizes.Size | str = sizes.text_lg,
|
19 |
+
font: fonts.Font
|
20 |
+
| str
|
21 |
+
| Iterable[fonts.Font | str] = (
|
22 |
+
fonts.GoogleFont("Quicksand"),
|
23 |
+
"ui-sans-serif",
|
24 |
+
"sans-serif",
|
25 |
+
),
|
26 |
+
font_mono: fonts.Font
|
27 |
+
| str
|
28 |
+
| Iterable[fonts.Font | str] = (
|
29 |
+
fonts.GoogleFont("IBM Plex Mono"),
|
30 |
+
"ui-monospace",
|
31 |
+
"monospace",
|
32 |
+
),
|
33 |
+
):
|
34 |
+
super().__init__(
|
35 |
+
primary_hue=primary_hue,
|
36 |
+
secondary_hue=secondary_hue,
|
37 |
+
neutral_hue=neutral_hue,
|
38 |
+
spacing_size=spacing_size,
|
39 |
+
radius_size=radius_size,
|
40 |
+
text_size=text_size,
|
41 |
+
font=font,
|
42 |
+
font_mono=font_mono,
|
43 |
+
)
|
IndicPhotoOCR/utils/strhub/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# from data.module import SceneTextDataModule
|
2 |
+
# from model.utils import load_from_checkpoint
|
IndicPhotoOCR/utils/strhub/data/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# from .module import SceneTextDataModule
|
IndicPhotoOCR/utils/strhub/data/aa_overrides.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scene Text Recognition Model Hub
|
2 |
+
# Copyright 2022 Darwin Bautista
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Extends default ops to accept optional parameters."""
|
17 |
+
from functools import partial
|
18 |
+
|
19 |
+
from timm.data.auto_augment import _LEVEL_DENOM, LEVEL_TO_ARG, NAME_TO_OP, _randomly_negate, rotate
|
20 |
+
|
21 |
+
|
22 |
+
def rotate_expand(img, degrees, **kwargs):
|
23 |
+
"""Rotate operation with expand=True to avoid cutting off the characters"""
|
24 |
+
kwargs['expand'] = True
|
25 |
+
return rotate(img, degrees, **kwargs)
|
26 |
+
|
27 |
+
|
28 |
+
def _level_to_arg(level, hparams, key, default):
|
29 |
+
magnitude = hparams.get(key, default)
|
30 |
+
level = (level / _LEVEL_DENOM) * magnitude
|
31 |
+
level = _randomly_negate(level)
|
32 |
+
return (level,)
|
33 |
+
|
34 |
+
|
35 |
+
def apply():
|
36 |
+
# Overrides
|
37 |
+
NAME_TO_OP.update({
|
38 |
+
'Rotate': rotate_expand,
|
39 |
+
})
|
40 |
+
LEVEL_TO_ARG.update({
|
41 |
+
'Rotate': partial(_level_to_arg, key='rotate_deg', default=30.0),
|
42 |
+
'ShearX': partial(_level_to_arg, key='shear_x_pct', default=0.3),
|
43 |
+
'ShearY': partial(_level_to_arg, key='shear_y_pct', default=0.3),
|
44 |
+
'TranslateXRel': partial(_level_to_arg, key='translate_x_pct', default=0.45),
|
45 |
+
'TranslateYRel': partial(_level_to_arg, key='translate_y_pct', default=0.45),
|
46 |
+
})
|
IndicPhotoOCR/utils/strhub/data/augment.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scene Text Recognition Model Hub
|
2 |
+
# Copyright 2022 Darwin Bautista
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from functools import partial
|
17 |
+
|
18 |
+
import imgaug.augmenters as iaa
|
19 |
+
import numpy as np
|
20 |
+
from PIL import Image, ImageFilter
|
21 |
+
|
22 |
+
from timm.data import auto_augment
|
23 |
+
|
24 |
+
from strhub.data import aa_overrides
|
25 |
+
|
26 |
+
aa_overrides.apply()
|
27 |
+
|
28 |
+
_OP_CACHE = {}
|
29 |
+
|
30 |
+
|
31 |
+
def _get_op(key, factory):
|
32 |
+
try:
|
33 |
+
op = _OP_CACHE[key]
|
34 |
+
except KeyError:
|
35 |
+
op = factory()
|
36 |
+
_OP_CACHE[key] = op
|
37 |
+
return op
|
38 |
+
|
39 |
+
|
40 |
+
def _get_param(level, img, max_dim_factor, min_level=1):
|
41 |
+
max_level = max(min_level, max_dim_factor * max(img.size))
|
42 |
+
return round(min(level, max_level))
|
43 |
+
|
44 |
+
|
45 |
+
def gaussian_blur(img, radius, **__):
|
46 |
+
radius = _get_param(radius, img, 0.02)
|
47 |
+
key = 'gaussian_blur_' + str(radius)
|
48 |
+
op = _get_op(key, lambda: ImageFilter.GaussianBlur(radius))
|
49 |
+
return img.filter(op)
|
50 |
+
|
51 |
+
|
52 |
+
def motion_blur(img, k, **__):
|
53 |
+
k = _get_param(k, img, 0.08, 3) | 1 # bin to odd values
|
54 |
+
key = 'motion_blur_' + str(k)
|
55 |
+
op = _get_op(key, lambda: iaa.MotionBlur(k))
|
56 |
+
return Image.fromarray(op(image=np.asarray(img)))
|
57 |
+
|
58 |
+
|
59 |
+
def gaussian_noise(img, scale, **_):
|
60 |
+
scale = _get_param(scale, img, 0.25) | 1 # bin to odd values
|
61 |
+
key = 'gaussian_noise_' + str(scale)
|
62 |
+
op = _get_op(key, lambda: iaa.AdditiveGaussianNoise(scale=scale))
|
63 |
+
return Image.fromarray(op(image=np.asarray(img)))
|
64 |
+
|
65 |
+
|
66 |
+
def poisson_noise(img, lam, **_):
|
67 |
+
lam = _get_param(lam, img, 0.2) | 1 # bin to odd values
|
68 |
+
key = 'poisson_noise_' + str(lam)
|
69 |
+
op = _get_op(key, lambda: iaa.AdditivePoissonNoise(lam))
|
70 |
+
return Image.fromarray(op(image=np.asarray(img)))
|
71 |
+
|
72 |
+
|
73 |
+
def _level_to_arg(level, _hparams, max):
|
74 |
+
level = max * level / auto_augment._LEVEL_DENOM
|
75 |
+
return (level,)
|
76 |
+
|
77 |
+
|
78 |
+
_RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy()
|
79 |
+
_RAND_TRANSFORMS.remove('SharpnessIncreasing') # remove, interferes with *blur ops
|
80 |
+
_RAND_TRANSFORMS.extend([
|
81 |
+
'GaussianBlur',
|
82 |
+
# 'MotionBlur',
|
83 |
+
# 'GaussianNoise',
|
84 |
+
'PoissonNoise',
|
85 |
+
])
|
86 |
+
auto_augment.LEVEL_TO_ARG.update({
|
87 |
+
'GaussianBlur': partial(_level_to_arg, max=4),
|
88 |
+
'MotionBlur': partial(_level_to_arg, max=20),
|
89 |
+
'GaussianNoise': partial(_level_to_arg, max=0.1 * 255),
|
90 |
+
'PoissonNoise': partial(_level_to_arg, max=40),
|
91 |
+
})
|
92 |
+
auto_augment.NAME_TO_OP.update({
|
93 |
+
'GaussianBlur': gaussian_blur,
|
94 |
+
'MotionBlur': motion_blur,
|
95 |
+
'GaussianNoise': gaussian_noise,
|
96 |
+
'PoissonNoise': poisson_noise,
|
97 |
+
})
|
98 |
+
|
99 |
+
|
100 |
+
def rand_augment_transform(magnitude=5, num_layers=3):
|
101 |
+
# These are tuned for magnitude=5, which means that effective magnitudes are half of these values.
|
102 |
+
hparams = {
|
103 |
+
'rotate_deg': 30,
|
104 |
+
'shear_x_pct': 0.9,
|
105 |
+
'shear_y_pct': 0.2,
|
106 |
+
'translate_x_pct': 0.10,
|
107 |
+
'translate_y_pct': 0.30,
|
108 |
+
}
|
109 |
+
ra_ops = auto_augment.rand_augment_ops(magnitude, hparams=hparams, transforms=_RAND_TRANSFORMS)
|
110 |
+
# Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice)
|
111 |
+
choice_weights = [1.0 / len(ra_ops) for _ in range(len(ra_ops))]
|
112 |
+
return auto_augment.RandAugment(ra_ops, num_layers, choice_weights)
|
IndicPhotoOCR/utils/strhub/data/dataset.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scene Text Recognition Model Hub
|
2 |
+
# Copyright 2022 Darwin Bautista
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import glob
|
16 |
+
import io
|
17 |
+
import logging
|
18 |
+
import unicodedata
|
19 |
+
from pathlib import Path, PurePath
|
20 |
+
from typing import Callable, Optional, Union
|
21 |
+
|
22 |
+
import lmdb
|
23 |
+
from PIL import Image
|
24 |
+
|
25 |
+
from torch.utils.data import ConcatDataset, Dataset
|
26 |
+
|
27 |
+
from IndicPhotoOCR.utils.strhub.data.utils import CharsetAdapter
|
28 |
+
|
29 |
+
log = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
def build_tree_dataset(root: Union[PurePath, str], *args, **kwargs):
|
33 |
+
try:
|
34 |
+
kwargs.pop('root') # prevent 'root' from being passed via kwargs
|
35 |
+
except KeyError:
|
36 |
+
pass
|
37 |
+
root = Path(root).absolute()
|
38 |
+
log.info(f'dataset root:\t{root}')
|
39 |
+
datasets = []
|
40 |
+
for mdb in glob.glob(str(root / '**/data.mdb'), recursive=True):
|
41 |
+
mdb = Path(mdb)
|
42 |
+
ds_name = str(mdb.parent.relative_to(root))
|
43 |
+
ds_root = str(mdb.parent.absolute())
|
44 |
+
dataset = LmdbDataset(ds_root, *args, **kwargs)
|
45 |
+
log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}')
|
46 |
+
datasets.append(dataset)
|
47 |
+
return ConcatDataset(datasets)
|
48 |
+
|
49 |
+
|
50 |
+
class LmdbDataset(Dataset):
|
51 |
+
"""Dataset interface to an LMDB database.
|
52 |
+
|
53 |
+
It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned
|
54 |
+
as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset.
|
55 |
+
Labels are transformed according to the charset.
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
root: str,
|
61 |
+
charset: str,
|
62 |
+
max_label_len: int,
|
63 |
+
min_image_dim: int = 0,
|
64 |
+
remove_whitespace: bool = True,
|
65 |
+
normalize_unicode: bool = True,
|
66 |
+
unlabelled: bool = False,
|
67 |
+
transform: Optional[Callable] = None,
|
68 |
+
):
|
69 |
+
self._env = None
|
70 |
+
self.root = root
|
71 |
+
self.unlabelled = unlabelled
|
72 |
+
self.transform = transform
|
73 |
+
self.labels = []
|
74 |
+
self.filtered_index_list = []
|
75 |
+
self.num_samples = self._preprocess_labels(
|
76 |
+
charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim
|
77 |
+
)
|
78 |
+
|
79 |
+
def __del__(self):
|
80 |
+
if self._env is not None:
|
81 |
+
self._env.close()
|
82 |
+
self._env = None
|
83 |
+
|
84 |
+
def _create_env(self):
|
85 |
+
return lmdb.open(
|
86 |
+
self.root, max_readers=1, readonly=True, create=False, readahead=False, meminit=False, lock=False
|
87 |
+
)
|
88 |
+
|
89 |
+
@property
|
90 |
+
def env(self):
|
91 |
+
if self._env is None:
|
92 |
+
self._env = self._create_env()
|
93 |
+
return self._env
|
94 |
+
|
95 |
+
def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim):
|
96 |
+
charset_adapter = CharsetAdapter(charset)
|
97 |
+
with self._create_env() as env, env.begin() as txn:
|
98 |
+
num_samples = int(txn.get('num-samples'.encode()))
|
99 |
+
if self.unlabelled:
|
100 |
+
return num_samples
|
101 |
+
for index in range(num_samples):
|
102 |
+
index += 1 # lmdb starts with 1
|
103 |
+
label_key = f'label-{index:09d}'.encode()
|
104 |
+
label = txn.get(label_key).decode()
|
105 |
+
# Normally, whitespace is removed from the labels.
|
106 |
+
if remove_whitespace:
|
107 |
+
label = ''.join(label.split())
|
108 |
+
# Normalize unicode composites (if any) and convert to compatible ASCII characters
|
109 |
+
if normalize_unicode:
|
110 |
+
label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode()
|
111 |
+
# Filter by length before removing unsupported characters. The original label might be too long.
|
112 |
+
if len(label) > max_label_len:
|
113 |
+
continue
|
114 |
+
label = charset_adapter(label)
|
115 |
+
# We filter out samples which don't contain any supported characters
|
116 |
+
if not label:
|
117 |
+
continue
|
118 |
+
# Filter images that are too small.
|
119 |
+
if min_image_dim > 0:
|
120 |
+
img_key = f'image-{index:09d}'.encode()
|
121 |
+
buf = io.BytesIO(txn.get(img_key))
|
122 |
+
w, h = Image.open(buf).size
|
123 |
+
if w < self.min_image_dim or h < self.min_image_dim:
|
124 |
+
continue
|
125 |
+
self.labels.append(label)
|
126 |
+
self.filtered_index_list.append(index)
|
127 |
+
return len(self.labels)
|
128 |
+
|
129 |
+
def __len__(self):
|
130 |
+
return self.num_samples
|
131 |
+
|
132 |
+
def __getitem__(self, index):
|
133 |
+
if self.unlabelled:
|
134 |
+
label = index
|
135 |
+
else:
|
136 |
+
label = self.labels[index]
|
137 |
+
index = self.filtered_index_list[index]
|
138 |
+
|
139 |
+
img_key = f'image-{index:09d}'.encode()
|
140 |
+
with self.env.begin() as txn:
|
141 |
+
imgbuf = txn.get(img_key)
|
142 |
+
buf = io.BytesIO(imgbuf)
|
143 |
+
img = Image.open(buf).convert('RGB')
|
144 |
+
|
145 |
+
if self.transform is not None:
|
146 |
+
img = self.transform(img)
|
147 |
+
|
148 |
+
return img, label
|
IndicPhotoOCR/utils/strhub/data/module.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scene Text Recognition Model Hub
|
2 |
+
# Copyright 2022 Darwin Bautista
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from pathlib import PurePath
|
17 |
+
from typing import Callable, Optional, Sequence
|
18 |
+
|
19 |
+
from torch.utils.data import DataLoader
|
20 |
+
from torchvision import transforms as T
|
21 |
+
|
22 |
+
import pytorch_lightning as pl
|
23 |
+
|
24 |
+
from IndicPhotoOCR.utils.strhub.data.dataset import LmdbDataset, build_tree_dataset
|
25 |
+
|
26 |
+
|
27 |
+
class SceneTextDataModule(pl.LightningDataModule):
|
28 |
+
TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80')
|
29 |
+
TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80')
|
30 |
+
TEST_NEW = ('ArT', 'COCOv1.4', 'Uber')
|
31 |
+
TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW))
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
root_dir: str,
|
36 |
+
train_dir: str,
|
37 |
+
img_size: Sequence[int],
|
38 |
+
max_label_length: int,
|
39 |
+
charset_train: str,
|
40 |
+
charset_test: str,
|
41 |
+
batch_size: int,
|
42 |
+
num_workers: int,
|
43 |
+
augment: bool,
|
44 |
+
remove_whitespace: bool = True,
|
45 |
+
normalize_unicode: bool = True,
|
46 |
+
min_image_dim: int = 0,
|
47 |
+
rotation: int = 0,
|
48 |
+
collate_fn: Optional[Callable] = None,
|
49 |
+
):
|
50 |
+
super().__init__()
|
51 |
+
self.root_dir = root_dir
|
52 |
+
self.train_dir = train_dir
|
53 |
+
self.img_size = tuple(img_size)
|
54 |
+
self.max_label_length = max_label_length
|
55 |
+
self.charset_train = charset_train
|
56 |
+
self.charset_test = charset_test
|
57 |
+
self.batch_size = batch_size
|
58 |
+
self.num_workers = num_workers
|
59 |
+
self.augment = augment
|
60 |
+
self.remove_whitespace = remove_whitespace
|
61 |
+
self.normalize_unicode = normalize_unicode
|
62 |
+
self.min_image_dim = min_image_dim
|
63 |
+
self.rotation = rotation
|
64 |
+
self.collate_fn = collate_fn
|
65 |
+
self._train_dataset = None
|
66 |
+
self._val_dataset = None
|
67 |
+
|
68 |
+
@staticmethod
|
69 |
+
def get_transform(img_size: tuple[int], augment: bool = False, rotation: int = 0):
|
70 |
+
transforms = []
|
71 |
+
if augment:
|
72 |
+
from .augment import rand_augment_transform
|
73 |
+
|
74 |
+
transforms.append(rand_augment_transform())
|
75 |
+
if rotation:
|
76 |
+
transforms.append(lambda img: img.rotate(rotation, expand=True))
|
77 |
+
transforms.extend([
|
78 |
+
T.Resize(img_size, T.InterpolationMode.BICUBIC),
|
79 |
+
T.ToTensor(),
|
80 |
+
T.Normalize(0.5, 0.5),
|
81 |
+
])
|
82 |
+
return T.Compose(transforms)
|
83 |
+
|
84 |
+
@property
|
85 |
+
def train_dataset(self):
|
86 |
+
if self._train_dataset is None:
|
87 |
+
transform = self.get_transform(self.img_size, self.augment)
|
88 |
+
root = PurePath(self.root_dir, 'train', self.train_dir)
|
89 |
+
self._train_dataset = build_tree_dataset(
|
90 |
+
root,
|
91 |
+
self.charset_train,
|
92 |
+
self.max_label_length,
|
93 |
+
self.min_image_dim,
|
94 |
+
self.remove_whitespace,
|
95 |
+
self.normalize_unicode,
|
96 |
+
transform=transform,
|
97 |
+
)
|
98 |
+
return self._train_dataset
|
99 |
+
|
100 |
+
@property
|
101 |
+
def val_dataset(self):
|
102 |
+
if self._val_dataset is None:
|
103 |
+
transform = self.get_transform(self.img_size)
|
104 |
+
root = PurePath(self.root_dir, 'val')
|
105 |
+
self._val_dataset = build_tree_dataset(
|
106 |
+
root,
|
107 |
+
self.charset_test,
|
108 |
+
self.max_label_length,
|
109 |
+
self.min_image_dim,
|
110 |
+
self.remove_whitespace,
|
111 |
+
self.normalize_unicode,
|
112 |
+
transform=transform,
|
113 |
+
)
|
114 |
+
return self._val_dataset
|
115 |
+
|
116 |
+
def train_dataloader(self):
|
117 |
+
return DataLoader(
|
118 |
+
self.train_dataset,
|
119 |
+
batch_size=self.batch_size,
|
120 |
+
shuffle=True,
|
121 |
+
num_workers=self.num_workers,
|
122 |
+
persistent_workers=self.num_workers > 0,
|
123 |
+
pin_memory=True,
|
124 |
+
collate_fn=self.collate_fn,
|
125 |
+
)
|
126 |
+
|
127 |
+
def val_dataloader(self):
|
128 |
+
return DataLoader(
|
129 |
+
self.val_dataset,
|
130 |
+
batch_size=self.batch_size,
|
131 |
+
num_workers=self.num_workers,
|
132 |
+
persistent_workers=self.num_workers > 0,
|
133 |
+
pin_memory=True,
|
134 |
+
collate_fn=self.collate_fn,
|
135 |
+
)
|
136 |
+
|
137 |
+
def test_dataloaders(self, subset):
|
138 |
+
transform = self.get_transform(self.img_size, rotation=self.rotation)
|
139 |
+
root = PurePath(self.root_dir, 'test')
|
140 |
+
datasets = {
|
141 |
+
s: LmdbDataset(
|
142 |
+
str(root / s),
|
143 |
+
self.charset_test,
|
144 |
+
self.max_label_length,
|
145 |
+
self.min_image_dim,
|
146 |
+
self.remove_whitespace,
|
147 |
+
self.normalize_unicode,
|
148 |
+
transform=transform,
|
149 |
+
)
|
150 |
+
for s in subset
|
151 |
+
}
|
152 |
+
return {
|
153 |
+
k: DataLoader(
|
154 |
+
v, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=self.collate_fn
|
155 |
+
)
|
156 |
+
for k, v in datasets.items()
|
157 |
+
}
|
IndicPhotoOCR/utils/strhub/data/utils.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scene Text Recognition Model Hub
|
2 |
+
# Copyright 2022 Darwin Bautista
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import re
|
17 |
+
from abc import ABC, abstractmethod
|
18 |
+
from itertools import groupby
|
19 |
+
from typing import Optional
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from torch import Tensor
|
23 |
+
from torch.nn.utils.rnn import pad_sequence
|
24 |
+
|
25 |
+
|
26 |
+
class CharsetAdapter:
|
27 |
+
"""Transforms labels according to the target charset."""
|
28 |
+
|
29 |
+
def __init__(self, target_charset) -> None:
|
30 |
+
super().__init__()
|
31 |
+
self.lowercase_only = target_charset == target_charset.lower()
|
32 |
+
self.uppercase_only = target_charset == target_charset.upper()
|
33 |
+
self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')
|
34 |
+
|
35 |
+
def __call__(self, label):
|
36 |
+
if self.lowercase_only:
|
37 |
+
label = label.lower()
|
38 |
+
elif self.uppercase_only:
|
39 |
+
label = label.upper()
|
40 |
+
# Remove unsupported characters
|
41 |
+
label = self.unsupported.sub('', label)
|
42 |
+
return label
|
43 |
+
|
44 |
+
|
45 |
+
class BaseTokenizer(ABC):
|
46 |
+
|
47 |
+
def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
|
48 |
+
self._itos = specials_first + tuple(charset) + specials_last
|
49 |
+
self._stoi = {s: i for i, s in enumerate(self._itos)}
|
50 |
+
|
51 |
+
def __len__(self):
|
52 |
+
return len(self._itos)
|
53 |
+
|
54 |
+
def _tok2ids(self, tokens: str) -> list[int]:
|
55 |
+
return [self._stoi[s] for s in tokens]
|
56 |
+
|
57 |
+
def _ids2tok(self, token_ids: list[int], join: bool = True) -> str:
|
58 |
+
tokens = [self._itos[i] for i in token_ids]
|
59 |
+
return ''.join(tokens) if join else tokens
|
60 |
+
|
61 |
+
@abstractmethod
|
62 |
+
def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
|
63 |
+
"""Encode a batch of labels to a representation suitable for the model.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
labels: List of labels. Each can be of arbitrary length.
|
67 |
+
device: Create tensor on this device.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
Batched tensor representation padded to the max label length. Shape: N, L
|
71 |
+
"""
|
72 |
+
raise NotImplementedError
|
73 |
+
|
74 |
+
@abstractmethod
|
75 |
+
def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
|
76 |
+
"""Internal method which performs the necessary filtering prior to decoding."""
|
77 |
+
raise NotImplementedError
|
78 |
+
|
79 |
+
def decode(self, token_dists: Tensor, raw: bool = False) -> tuple[list[str], list[Tensor]]:
|
80 |
+
"""Decode a batch of token distributions.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
token_dists: softmax probabilities over the token distribution. Shape: N, L, C
|
84 |
+
raw: return unprocessed labels (will return list of list of strings)
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
list of string labels (arbitrary length) and
|
88 |
+
their corresponding sequence probabilities as a list of Tensors
|
89 |
+
"""
|
90 |
+
batch_tokens = []
|
91 |
+
batch_probs = []
|
92 |
+
for dist in token_dists:
|
93 |
+
probs, ids = dist.max(-1) # greedy selection
|
94 |
+
if not raw:
|
95 |
+
probs, ids = self._filter(probs, ids)
|
96 |
+
tokens = self._ids2tok(ids, not raw)
|
97 |
+
batch_tokens.append(tokens)
|
98 |
+
batch_probs.append(probs)
|
99 |
+
return batch_tokens, batch_probs
|
100 |
+
|
101 |
+
|
102 |
+
class Tokenizer(BaseTokenizer):
|
103 |
+
BOS = '[B]'
|
104 |
+
EOS = '[E]'
|
105 |
+
PAD = '[P]'
|
106 |
+
|
107 |
+
def __init__(self, charset: str) -> None:
|
108 |
+
specials_first = (self.EOS,)
|
109 |
+
specials_last = (self.BOS, self.PAD)
|
110 |
+
super().__init__(charset, specials_first, specials_last)
|
111 |
+
self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last]
|
112 |
+
|
113 |
+
def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
|
114 |
+
batch = [
|
115 |
+
torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device)
|
116 |
+
for y in labels
|
117 |
+
]
|
118 |
+
return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)
|
119 |
+
|
120 |
+
def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
|
121 |
+
ids = ids.tolist()
|
122 |
+
try:
|
123 |
+
eos_idx = ids.index(self.eos_id)
|
124 |
+
except ValueError:
|
125 |
+
eos_idx = len(ids) # Nothing to truncate.
|
126 |
+
# Truncate after EOS
|
127 |
+
ids = ids[:eos_idx]
|
128 |
+
probs = probs[: eos_idx + 1] # but include prob. for EOS (if it exists)
|
129 |
+
return probs, ids
|
130 |
+
|
131 |
+
|
132 |
+
class CTCTokenizer(BaseTokenizer):
|
133 |
+
BLANK = '[B]'
|
134 |
+
|
135 |
+
def __init__(self, charset: str) -> None:
|
136 |
+
# BLANK uses index == 0 by default
|
137 |
+
super().__init__(charset, specials_first=(self.BLANK,))
|
138 |
+
self.blank_id = self._stoi[self.BLANK]
|
139 |
+
|
140 |
+
def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
|
141 |
+
# We use a padded representation since we don't want to use CUDNN's CTC implementation
|
142 |
+
batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels]
|
143 |
+
return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)
|
144 |
+
|
145 |
+
def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
|
146 |
+
# Best path decoding:
|
147 |
+
ids = list(zip(*groupby(ids.tolist())))[0] # Remove duplicate tokens
|
148 |
+
ids = [x for x in ids if x != self.blank_id] # Remove BLANKs
|
149 |
+
# `probs` is just pass-through since all positions are considered part of the path
|
150 |
+
return probs, ids
|
IndicPhotoOCR/utils/strhub/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# from .utils import load_from_checkpoint
|
IndicPhotoOCR/utils/strhub/models/abinet/LICENSE
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ABINet for non-commercial purposes
|
2 |
+
|
3 |
+
Copyright (c) 2021, USTC
|
4 |
+
All rights reserved.
|
5 |
+
|
6 |
+
Redistribution and use in source and binary forms, with or without
|
7 |
+
modification, are permitted provided that the following conditions are met:
|
8 |
+
|
9 |
+
1. Redistributions of source code must retain the above copyright notice, this
|
10 |
+
list of conditions and the following disclaimer.
|
11 |
+
|
12 |
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
13 |
+
this list of conditions and the following disclaimer in the documentation
|
14 |
+
and/or other materials provided with the distribution.
|
15 |
+
|
16 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
17 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
18 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
19 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
20 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
21 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
22 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
23 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
24 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
IndicPhotoOCR/utils/strhub/models/abinet/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""
|
2 |
+
Fang, Shancheng, Hongtao, Xie, Yuxin, Wang, Zhendong, Mao, and Yongdong, Zhang.
|
3 |
+
"Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition." .
|
4 |
+
In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 7098-7107).2021.
|
5 |
+
|
6 |
+
https://arxiv.org/abs/2103.06495
|
7 |
+
|
8 |
+
All source files, except `system.py`, are based on the implementation listed below,
|
9 |
+
and hence are released under the license of the original.
|
10 |
+
|
11 |
+
Source: https://github.com/FangShancheng/ABINet
|
12 |
+
License: 2-clause BSD License (see included LICENSE file)
|
13 |
+
"""
|
IndicPhotoOCR/utils/strhub/models/abinet/attention.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .transformer import PositionalEncoding
|
5 |
+
|
6 |
+
|
7 |
+
class Attention(nn.Module):
|
8 |
+
def __init__(self, in_channels=512, max_length=25, n_feature=256):
|
9 |
+
super().__init__()
|
10 |
+
self.max_length = max_length
|
11 |
+
|
12 |
+
self.f0_embedding = nn.Embedding(max_length, in_channels)
|
13 |
+
self.w0 = nn.Linear(max_length, n_feature)
|
14 |
+
self.wv = nn.Linear(in_channels, in_channels)
|
15 |
+
self.we = nn.Linear(in_channels, max_length)
|
16 |
+
|
17 |
+
self.active = nn.Tanh()
|
18 |
+
self.softmax = nn.Softmax(dim=2)
|
19 |
+
|
20 |
+
def forward(self, enc_output):
|
21 |
+
enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2)
|
22 |
+
reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device)
|
23 |
+
reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S)
|
24 |
+
reading_order_embed = self.f0_embedding(reading_order) # b,25,512
|
25 |
+
|
26 |
+
t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256
|
27 |
+
t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512
|
28 |
+
|
29 |
+
attn = self.we(t) # b,256,25
|
30 |
+
attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256
|
31 |
+
g_output = torch.bmm(attn, enc_output) # b,25,512
|
32 |
+
return g_output, attn.view(*attn.shape[:2], 8, 32)
|
33 |
+
|
34 |
+
|
35 |
+
def encoder_layer(in_c, out_c, k=3, s=2, p=1):
|
36 |
+
return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p),
|
37 |
+
nn.BatchNorm2d(out_c),
|
38 |
+
nn.ReLU(True))
|
39 |
+
|
40 |
+
|
41 |
+
def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None):
|
42 |
+
align_corners = None if mode == 'nearest' else True
|
43 |
+
return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor,
|
44 |
+
mode=mode, align_corners=align_corners),
|
45 |
+
nn.Conv2d(in_c, out_c, k, s, p),
|
46 |
+
nn.BatchNorm2d(out_c),
|
47 |
+
nn.ReLU(True))
|
48 |
+
|
49 |
+
|
50 |
+
class PositionAttention(nn.Module):
|
51 |
+
def __init__(self, max_length, in_channels=512, num_channels=64,
|
52 |
+
h=8, w=32, mode='nearest', **kwargs):
|
53 |
+
super().__init__()
|
54 |
+
self.max_length = max_length
|
55 |
+
self.k_encoder = nn.Sequential(
|
56 |
+
encoder_layer(in_channels, num_channels, s=(1, 2)),
|
57 |
+
encoder_layer(num_channels, num_channels, s=(2, 2)),
|
58 |
+
encoder_layer(num_channels, num_channels, s=(2, 2)),
|
59 |
+
encoder_layer(num_channels, num_channels, s=(2, 2))
|
60 |
+
)
|
61 |
+
self.k_decoder = nn.Sequential(
|
62 |
+
decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
|
63 |
+
decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
|
64 |
+
decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
|
65 |
+
decoder_layer(num_channels, in_channels, size=(h, w), mode=mode)
|
66 |
+
)
|
67 |
+
|
68 |
+
self.pos_encoder = PositionalEncoding(in_channels, dropout=0., max_len=max_length)
|
69 |
+
self.project = nn.Linear(in_channels, in_channels)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
N, E, H, W = x.size()
|
73 |
+
k, v = x, x # (N, E, H, W)
|
74 |
+
|
75 |
+
# calculate key vector
|
76 |
+
features = []
|
77 |
+
for i in range(0, len(self.k_encoder)):
|
78 |
+
k = self.k_encoder[i](k)
|
79 |
+
features.append(k)
|
80 |
+
for i in range(0, len(self.k_decoder) - 1):
|
81 |
+
k = self.k_decoder[i](k)
|
82 |
+
k = k + features[len(self.k_decoder) - 2 - i]
|
83 |
+
k = self.k_decoder[-1](k)
|
84 |
+
|
85 |
+
# calculate query vector
|
86 |
+
# TODO q=f(q,k)
|
87 |
+
zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E)
|
88 |
+
q = self.pos_encoder(zeros) # (T, N, E)
|
89 |
+
q = q.permute(1, 0, 2) # (N, T, E)
|
90 |
+
q = self.project(q) # (N, T, E)
|
91 |
+
|
92 |
+
# calculate attention
|
93 |
+
attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W))
|
94 |
+
attn_scores = attn_scores / (E ** 0.5)
|
95 |
+
attn_scores = torch.softmax(attn_scores, dim=-1)
|
96 |
+
|
97 |
+
v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E)
|
98 |
+
attn_vecs = torch.bmm(attn_scores, v) # (N, T, E)
|
99 |
+
|
100 |
+
return attn_vecs, attn_scores.view(N, -1, H, W)
|
IndicPhotoOCR/utils/strhub/models/abinet/backbone.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from torch.nn import TransformerEncoderLayer, TransformerEncoder
|
3 |
+
|
4 |
+
from .resnet import resnet45
|
5 |
+
from .transformer import PositionalEncoding
|
6 |
+
|
7 |
+
|
8 |
+
class ResTranformer(nn.Module):
|
9 |
+
def __init__(self, d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', backbone_ln=2):
|
10 |
+
super().__init__()
|
11 |
+
self.resnet = resnet45()
|
12 |
+
self.pos_encoder = PositionalEncoding(d_model, max_len=8 * 32)
|
13 |
+
encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead,
|
14 |
+
dim_feedforward=d_inner, dropout=dropout, activation=activation)
|
15 |
+
self.transformer = TransformerEncoder(encoder_layer, backbone_ln)
|
16 |
+
|
17 |
+
def forward(self, images):
|
18 |
+
feature = self.resnet(images)
|
19 |
+
n, c, h, w = feature.shape
|
20 |
+
feature = feature.view(n, c, -1).permute(2, 0, 1)
|
21 |
+
feature = self.pos_encoder(feature)
|
22 |
+
feature = self.transformer(feature)
|
23 |
+
feature = feature.permute(1, 2, 0).view(n, c, h, w)
|
24 |
+
return feature
|
IndicPhotoOCR/utils/strhub/models/abinet/model.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class Model(nn.Module):
|
6 |
+
|
7 |
+
def __init__(self, dataset_max_length: int, null_label: int):
|
8 |
+
super().__init__()
|
9 |
+
self.max_length = dataset_max_length + 1 # additional stop token
|
10 |
+
self.null_label = null_label
|
11 |
+
|
12 |
+
def _get_length(self, logit, dim=-1):
|
13 |
+
""" Greed decoder to obtain length from logit"""
|
14 |
+
out = (logit.argmax(dim=-1) == self.null_label)
|
15 |
+
abn = out.any(dim)
|
16 |
+
out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
|
17 |
+
out = out + 1 # additional end token
|
18 |
+
out = torch.where(abn, out, out.new_tensor(logit.shape[1], device=out.device))
|
19 |
+
return out
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def _get_padding_mask(length, max_length):
|
23 |
+
length = length.unsqueeze(-1)
|
24 |
+
grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
|
25 |
+
return grid >= length
|
26 |
+
|
27 |
+
@staticmethod
|
28 |
+
def _get_location_mask(sz, device=None):
|
29 |
+
mask = torch.eye(sz, device=device)
|
30 |
+
mask = mask.float().masked_fill(mask == 1, float('-inf'))
|
31 |
+
return mask
|
IndicPhotoOCR/utils/strhub/models/abinet/model_abinet_iter.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
from .model_alignment import BaseAlignment
|
5 |
+
from .model_language import BCNLanguage
|
6 |
+
from .model_vision import BaseVision
|
7 |
+
|
8 |
+
|
9 |
+
class ABINetIterModel(nn.Module):
|
10 |
+
def __init__(self, dataset_max_length, null_label, num_classes, iter_size=1,
|
11 |
+
d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
|
12 |
+
v_loss_weight=1., v_attention='position', v_attention_mode='nearest',
|
13 |
+
v_backbone='transformer', v_num_layers=2,
|
14 |
+
l_loss_weight=1., l_num_layers=4, l_detach=True, l_use_self_attn=False,
|
15 |
+
a_loss_weight=1.):
|
16 |
+
super().__init__()
|
17 |
+
self.iter_size = iter_size
|
18 |
+
self.vision = BaseVision(dataset_max_length, null_label, num_classes, v_attention, v_attention_mode,
|
19 |
+
v_loss_weight, d_model, nhead, d_inner, dropout, activation, v_backbone, v_num_layers)
|
20 |
+
self.language = BCNLanguage(dataset_max_length, null_label, num_classes, d_model, nhead, d_inner, dropout,
|
21 |
+
activation, l_num_layers, l_detach, l_use_self_attn, l_loss_weight)
|
22 |
+
self.alignment = BaseAlignment(dataset_max_length, null_label, num_classes, d_model, a_loss_weight)
|
23 |
+
|
24 |
+
def forward(self, images):
|
25 |
+
v_res = self.vision(images)
|
26 |
+
a_res = v_res
|
27 |
+
all_l_res, all_a_res = [], []
|
28 |
+
for _ in range(self.iter_size):
|
29 |
+
tokens = torch.softmax(a_res['logits'], dim=-1)
|
30 |
+
lengths = a_res['pt_lengths']
|
31 |
+
lengths.clamp_(2, self.language.max_length) # TODO:move to langauge model
|
32 |
+
l_res = self.language(tokens, lengths)
|
33 |
+
all_l_res.append(l_res)
|
34 |
+
a_res = self.alignment(l_res['feature'], v_res['feature'])
|
35 |
+
all_a_res.append(a_res)
|
36 |
+
if self.training:
|
37 |
+
return all_a_res, all_l_res, v_res
|
38 |
+
else:
|
39 |
+
return a_res, all_l_res[-1], v_res
|
IndicPhotoOCR/utils/strhub/models/abinet/model_alignment.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .model import Model
|
5 |
+
|
6 |
+
|
7 |
+
class BaseAlignment(Model):
|
8 |
+
def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, loss_weight=1.0):
|
9 |
+
super().__init__(dataset_max_length, null_label)
|
10 |
+
self.loss_weight = loss_weight
|
11 |
+
self.w_att = nn.Linear(2 * d_model, d_model)
|
12 |
+
self.cls = nn.Linear(d_model, num_classes)
|
13 |
+
|
14 |
+
def forward(self, l_feature, v_feature):
|
15 |
+
"""
|
16 |
+
Args:
|
17 |
+
l_feature: (N, T, E) where T is length, N is batch size and d is dim of model
|
18 |
+
v_feature: (N, T, E) shape the same as l_feature
|
19 |
+
"""
|
20 |
+
f = torch.cat((l_feature, v_feature), dim=2)
|
21 |
+
f_att = torch.sigmoid(self.w_att(f))
|
22 |
+
output = f_att * v_feature + (1 - f_att) * l_feature
|
23 |
+
|
24 |
+
logits = self.cls(output) # (N, T, C)
|
25 |
+
pt_lengths = self._get_length(logits)
|
26 |
+
|
27 |
+
return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight': self.loss_weight,
|
28 |
+
'name': 'alignment'}
|
IndicPhotoOCR/utils/strhub/models/abinet/model_language.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
from .model import Model
|
4 |
+
from .transformer import PositionalEncoding, TransformerDecoderLayer, TransformerDecoder
|
5 |
+
|
6 |
+
|
7 |
+
class BCNLanguage(Model):
|
8 |
+
def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, nhead=8, d_inner=2048, dropout=0.1,
|
9 |
+
activation='relu', num_layers=4, detach=True, use_self_attn=False, loss_weight=1.0,
|
10 |
+
global_debug=False):
|
11 |
+
super().__init__(dataset_max_length, null_label)
|
12 |
+
self.detach = detach
|
13 |
+
self.loss_weight = loss_weight
|
14 |
+
self.proj = nn.Linear(num_classes, d_model, False)
|
15 |
+
self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length)
|
16 |
+
self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length)
|
17 |
+
decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout,
|
18 |
+
activation, self_attn=use_self_attn, debug=global_debug)
|
19 |
+
self.model = TransformerDecoder(decoder_layer, num_layers)
|
20 |
+
self.cls = nn.Linear(d_model, num_classes)
|
21 |
+
|
22 |
+
def forward(self, tokens, lengths):
|
23 |
+
"""
|
24 |
+
Args:
|
25 |
+
tokens: (N, T, C) where T is length, N is batch size and C is classes number
|
26 |
+
lengths: (N,)
|
27 |
+
"""
|
28 |
+
if self.detach:
|
29 |
+
tokens = tokens.detach()
|
30 |
+
embed = self.proj(tokens) # (N, T, E)
|
31 |
+
embed = embed.permute(1, 0, 2) # (T, N, E)
|
32 |
+
embed = self.token_encoder(embed) # (T, N, E)
|
33 |
+
padding_mask = self._get_padding_mask(lengths, self.max_length)
|
34 |
+
|
35 |
+
zeros = embed.new_zeros(*embed.shape)
|
36 |
+
qeury = self.pos_encoder(zeros)
|
37 |
+
location_mask = self._get_location_mask(self.max_length, tokens.device)
|
38 |
+
output = self.model(qeury, embed,
|
39 |
+
tgt_key_padding_mask=padding_mask,
|
40 |
+
memory_mask=location_mask,
|
41 |
+
memory_key_padding_mask=padding_mask) # (T, N, E)
|
42 |
+
output = output.permute(1, 0, 2) # (N, T, E)
|
43 |
+
|
44 |
+
logits = self.cls(output) # (N, T, C)
|
45 |
+
pt_lengths = self._get_length(logits)
|
46 |
+
|
47 |
+
res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths,
|
48 |
+
'loss_weight': self.loss_weight, 'name': 'language'}
|
49 |
+
return res
|
IndicPhotoOCR/utils/strhub/models/abinet/model_vision.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
from .attention import PositionAttention, Attention
|
4 |
+
from .backbone import ResTranformer
|
5 |
+
from .model import Model
|
6 |
+
from .resnet import resnet45
|
7 |
+
|
8 |
+
|
9 |
+
class BaseVision(Model):
|
10 |
+
def __init__(self, dataset_max_length, null_label, num_classes,
|
11 |
+
attention='position', attention_mode='nearest', loss_weight=1.0,
|
12 |
+
d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
|
13 |
+
backbone='transformer', backbone_ln=2):
|
14 |
+
super().__init__(dataset_max_length, null_label)
|
15 |
+
self.loss_weight = loss_weight
|
16 |
+
self.out_channels = d_model
|
17 |
+
|
18 |
+
if backbone == 'transformer':
|
19 |
+
self.backbone = ResTranformer(d_model, nhead, d_inner, dropout, activation, backbone_ln)
|
20 |
+
else:
|
21 |
+
self.backbone = resnet45()
|
22 |
+
|
23 |
+
if attention == 'position':
|
24 |
+
self.attention = PositionAttention(
|
25 |
+
max_length=self.max_length,
|
26 |
+
mode=attention_mode
|
27 |
+
)
|
28 |
+
elif attention == 'attention':
|
29 |
+
self.attention = Attention(
|
30 |
+
max_length=self.max_length,
|
31 |
+
n_feature=8 * 32,
|
32 |
+
)
|
33 |
+
else:
|
34 |
+
raise ValueError(f'invalid attention: {attention}')
|
35 |
+
|
36 |
+
self.cls = nn.Linear(self.out_channels, num_classes)
|
37 |
+
|
38 |
+
def forward(self, images):
|
39 |
+
features = self.backbone(images) # (N, E, H, W)
|
40 |
+
attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W)
|
41 |
+
logits = self.cls(attn_vecs) # (N, T, C)
|
42 |
+
pt_lengths = self._get_length(logits)
|
43 |
+
|
44 |
+
return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
|
45 |
+
'attn_scores': attn_scores, 'loss_weight': self.loss_weight, 'name': 'vision'}
|
IndicPhotoOCR/utils/strhub/models/abinet/resnet.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional, Callable
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
from torchvision.models import resnet
|
6 |
+
|
7 |
+
|
8 |
+
class BasicBlock(resnet.BasicBlock):
|
9 |
+
|
10 |
+
def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None,
|
11 |
+
groups: int = 1, base_width: int = 64, dilation: int = 1,
|
12 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
|
13 |
+
super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)
|
14 |
+
self.conv1 = resnet.conv1x1(inplanes, planes)
|
15 |
+
self.conv2 = resnet.conv3x3(planes, planes, stride)
|
16 |
+
|
17 |
+
|
18 |
+
class ResNet(nn.Module):
|
19 |
+
|
20 |
+
def __init__(self, block, layers):
|
21 |
+
super().__init__()
|
22 |
+
self.inplanes = 32
|
23 |
+
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1,
|
24 |
+
bias=False)
|
25 |
+
self.bn1 = nn.BatchNorm2d(32)
|
26 |
+
self.relu = nn.ReLU(inplace=True)
|
27 |
+
|
28 |
+
self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
|
29 |
+
self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
|
30 |
+
self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
|
31 |
+
self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
|
32 |
+
self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
|
33 |
+
|
34 |
+
for m in self.modules():
|
35 |
+
if isinstance(m, nn.Conv2d):
|
36 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
37 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
38 |
+
elif isinstance(m, nn.BatchNorm2d):
|
39 |
+
m.weight.data.fill_(1)
|
40 |
+
m.bias.data.zero_()
|
41 |
+
|
42 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
43 |
+
downsample = None
|
44 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
45 |
+
downsample = nn.Sequential(
|
46 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
47 |
+
kernel_size=1, stride=stride, bias=False),
|
48 |
+
nn.BatchNorm2d(planes * block.expansion),
|
49 |
+
)
|
50 |
+
|
51 |
+
layers = []
|
52 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
53 |
+
self.inplanes = planes * block.expansion
|
54 |
+
for i in range(1, blocks):
|
55 |
+
layers.append(block(self.inplanes, planes))
|
56 |
+
|
57 |
+
return nn.Sequential(*layers)
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
x = self.conv1(x)
|
61 |
+
x = self.bn1(x)
|
62 |
+
x = self.relu(x)
|
63 |
+
x = self.layer1(x)
|
64 |
+
x = self.layer2(x)
|
65 |
+
x = self.layer3(x)
|
66 |
+
x = self.layer4(x)
|
67 |
+
x = self.layer5(x)
|
68 |
+
return x
|
69 |
+
|
70 |
+
|
71 |
+
def resnet45():
|
72 |
+
return ResNet(BasicBlock, [3, 4, 6, 6, 3])
|
IndicPhotoOCR/utils/strhub/models/abinet/system.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scene Text Recognition Model Hub
|
2 |
+
# Copyright 2022 Darwin Bautista
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import logging
|
17 |
+
import math
|
18 |
+
from typing import Any, Optional
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.nn.functional as F
|
22 |
+
from torch import Tensor, nn
|
23 |
+
from torch.optim import AdamW
|
24 |
+
from torch.optim.lr_scheduler import OneCycleLR
|
25 |
+
|
26 |
+
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
27 |
+
from timm.optim.optim_factory import param_groups_weight_decay
|
28 |
+
|
29 |
+
from strhub.models.base import CrossEntropySystem
|
30 |
+
from strhub.models.utils import init_weights
|
31 |
+
|
32 |
+
from .model_abinet_iter import ABINetIterModel as Model
|
33 |
+
|
34 |
+
log = logging.getLogger(__name__)
|
35 |
+
|
36 |
+
|
37 |
+
class ABINet(CrossEntropySystem):
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
charset_train: str,
|
42 |
+
charset_test: str,
|
43 |
+
max_label_length: int,
|
44 |
+
batch_size: int,
|
45 |
+
lr: float,
|
46 |
+
warmup_pct: float,
|
47 |
+
weight_decay: float,
|
48 |
+
iter_size: int,
|
49 |
+
d_model: int,
|
50 |
+
nhead: int,
|
51 |
+
d_inner: int,
|
52 |
+
dropout: float,
|
53 |
+
activation: str,
|
54 |
+
v_loss_weight: float,
|
55 |
+
v_attention: str,
|
56 |
+
v_attention_mode: str,
|
57 |
+
v_backbone: str,
|
58 |
+
v_num_layers: int,
|
59 |
+
l_loss_weight: float,
|
60 |
+
l_num_layers: int,
|
61 |
+
l_detach: bool,
|
62 |
+
l_use_self_attn: bool,
|
63 |
+
l_lr: float,
|
64 |
+
a_loss_weight: float,
|
65 |
+
lm_only: bool = False,
|
66 |
+
**kwargs,
|
67 |
+
) -> None:
|
68 |
+
super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
|
69 |
+
self.scheduler = None
|
70 |
+
self.save_hyperparameters()
|
71 |
+
self.max_label_length = max_label_length
|
72 |
+
self.num_classes = len(self.tokenizer) - 2 # We don't predict <bos> nor <pad>
|
73 |
+
self.model = Model(
|
74 |
+
max_label_length,
|
75 |
+
self.eos_id,
|
76 |
+
self.num_classes,
|
77 |
+
iter_size,
|
78 |
+
d_model,
|
79 |
+
nhead,
|
80 |
+
d_inner,
|
81 |
+
dropout,
|
82 |
+
activation,
|
83 |
+
v_loss_weight,
|
84 |
+
v_attention,
|
85 |
+
v_attention_mode,
|
86 |
+
v_backbone,
|
87 |
+
v_num_layers,
|
88 |
+
l_loss_weight,
|
89 |
+
l_num_layers,
|
90 |
+
l_detach,
|
91 |
+
l_use_self_attn,
|
92 |
+
a_loss_weight,
|
93 |
+
)
|
94 |
+
self.model.apply(init_weights)
|
95 |
+
# FIXME: doesn't support resumption from checkpoint yet
|
96 |
+
self._reset_alignment = True
|
97 |
+
self._reset_optimizers = True
|
98 |
+
self.l_lr = l_lr
|
99 |
+
self.lm_only = lm_only
|
100 |
+
# Train LM only. Freeze other submodels.
|
101 |
+
if lm_only:
|
102 |
+
self.l_lr = lr # for tuning
|
103 |
+
self.model.vision.requires_grad_(False)
|
104 |
+
self.model.alignment.requires_grad_(False)
|
105 |
+
|
106 |
+
@property
|
107 |
+
def _pretraining(self):
|
108 |
+
# In the original work, VM was pretrained for 8 epochs while full model was trained for an additional 10 epochs.
|
109 |
+
total_steps = self.trainer.estimated_stepping_batches * self.trainer.accumulate_grad_batches
|
110 |
+
return self.global_step < (8 / (8 + 10)) * total_steps
|
111 |
+
|
112 |
+
@torch.jit.ignore
|
113 |
+
def no_weight_decay(self):
|
114 |
+
return {'model.language.proj.weight'}
|
115 |
+
|
116 |
+
def _add_weight_decay(self, model: nn.Module, skip_list=()):
|
117 |
+
if self.weight_decay:
|
118 |
+
return param_groups_weight_decay(model, self.weight_decay, skip_list)
|
119 |
+
else:
|
120 |
+
return [{'params': model.parameters()}]
|
121 |
+
|
122 |
+
def configure_optimizers(self):
|
123 |
+
agb = self.trainer.accumulate_grad_batches
|
124 |
+
# Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP.
|
125 |
+
lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.0
|
126 |
+
lr = lr_scale * self.lr
|
127 |
+
l_lr = lr_scale * self.l_lr
|
128 |
+
params = []
|
129 |
+
params.extend(self._add_weight_decay(self.model.vision))
|
130 |
+
params.extend(self._add_weight_decay(self.model.alignment))
|
131 |
+
# We use a different learning rate for the LM.
|
132 |
+
for p in self._add_weight_decay(self.model.language, ('proj.weight',)):
|
133 |
+
p['lr'] = l_lr
|
134 |
+
params.append(p)
|
135 |
+
max_lr = [p.get('lr', lr) for p in params]
|
136 |
+
optim = AdamW(params, lr)
|
137 |
+
self.scheduler = OneCycleLR(
|
138 |
+
optim, max_lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, cycle_momentum=False
|
139 |
+
)
|
140 |
+
return {'optimizer': optim, 'lr_scheduler': {'scheduler': self.scheduler, 'interval': 'step'}}
|
141 |
+
|
142 |
+
def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
|
143 |
+
max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
|
144 |
+
logits = self.model.forward(images)[0]['logits']
|
145 |
+
return logits[:, : max_length + 1] # truncate
|
146 |
+
|
147 |
+
def calc_loss(self, targets, *res_lists) -> Tensor:
|
148 |
+
total_loss = 0
|
149 |
+
for res_list in res_lists:
|
150 |
+
loss = 0
|
151 |
+
if isinstance(res_list, dict):
|
152 |
+
res_list = [res_list]
|
153 |
+
for res in res_list:
|
154 |
+
logits = res['logits'].flatten(end_dim=1)
|
155 |
+
loss += F.cross_entropy(logits, targets.flatten(), ignore_index=self.pad_id)
|
156 |
+
loss /= len(res_list)
|
157 |
+
self.log('loss_' + res_list[0]['name'], loss)
|
158 |
+
total_loss += res_list[0]['loss_weight'] * loss
|
159 |
+
return total_loss
|
160 |
+
|
161 |
+
def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
|
162 |
+
if not self._pretraining and self._reset_optimizers:
|
163 |
+
log.info('Pretraining ends. Updating base LRs.')
|
164 |
+
self._reset_optimizers = False
|
165 |
+
# Make base_lr the same for all groups
|
166 |
+
base_lr = self.scheduler.base_lrs[0] # base_lr of group 0 - VM
|
167 |
+
self.scheduler.base_lrs = [base_lr] * len(self.scheduler.base_lrs)
|
168 |
+
|
169 |
+
def _prepare_inputs_and_targets(self, labels):
|
170 |
+
# Use dummy label to ensure sequence length is constant.
|
171 |
+
dummy = ['0' * self.max_label_length]
|
172 |
+
targets = self.tokenizer.encode(dummy + list(labels), self.device)[1:]
|
173 |
+
targets = targets[:, 1:] # remove <bos>. Unused here.
|
174 |
+
# Inputs are padded with eos_id
|
175 |
+
inputs = torch.where(targets == self.pad_id, self.eos_id, targets)
|
176 |
+
inputs = F.one_hot(inputs, self.num_classes).float()
|
177 |
+
lengths = torch.as_tensor(list(map(len, labels)), device=self.device) + 1 # +1 for eos
|
178 |
+
return inputs, lengths, targets
|
179 |
+
|
180 |
+
def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
|
181 |
+
images, labels = batch
|
182 |
+
inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
|
183 |
+
if self.lm_only:
|
184 |
+
l_res = self.model.language(inputs, lengths)
|
185 |
+
loss = self.calc_loss(targets, l_res)
|
186 |
+
# Pretrain submodels independently first
|
187 |
+
elif self._pretraining:
|
188 |
+
# Vision
|
189 |
+
v_res = self.model.vision(images)
|
190 |
+
# Language
|
191 |
+
l_res = self.model.language(inputs, lengths)
|
192 |
+
# We also train the alignment model to 'satisfy' DDP requirements (all parameters should be used).
|
193 |
+
# We'll reset its parameters prior to joint training.
|
194 |
+
a_res = self.model.alignment(l_res['feature'].detach(), v_res['feature'].detach())
|
195 |
+
loss = self.calc_loss(targets, v_res, l_res, a_res)
|
196 |
+
else:
|
197 |
+
# Reset alignment model's parameters once prior to full model training.
|
198 |
+
if self._reset_alignment:
|
199 |
+
log.info('Pretraining ends. Resetting alignment model.')
|
200 |
+
self._reset_alignment = False
|
201 |
+
self.model.alignment.apply(init_weights)
|
202 |
+
all_a_res, all_l_res, v_res = self.model.forward(images)
|
203 |
+
loss = self.calc_loss(targets, v_res, all_l_res, all_a_res)
|
204 |
+
self.log('loss', loss)
|
205 |
+
return loss
|
206 |
+
|
207 |
+
def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
|
208 |
+
if self.lm_only:
|
209 |
+
inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
|
210 |
+
l_res = self.model.language(inputs, lengths)
|
211 |
+
loss = self.calc_loss(targets, l_res)
|
212 |
+
loss_numel = (targets != self.pad_id).sum()
|
213 |
+
return l_res['logits'], loss, loss_numel
|
214 |
+
else:
|
215 |
+
return super().forward_logits_loss(images, labels)
|
IndicPhotoOCR/utils/strhub/models/abinet/transformer.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn.modules.transformer import _get_activation_fn, _get_clones
|
7 |
+
|
8 |
+
|
9 |
+
class TransformerDecoder(nn.Module):
|
10 |
+
r"""TransformerDecoder is a stack of N decoder layers
|
11 |
+
|
12 |
+
Args:
|
13 |
+
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
|
14 |
+
num_layers: the number of sub-decoder-layers in the decoder (required).
|
15 |
+
norm: the layer normalization component (optional).
|
16 |
+
|
17 |
+
Examples::
|
18 |
+
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
|
19 |
+
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
|
20 |
+
>>> memory = torch.rand(10, 32, 512)
|
21 |
+
>>> tgt = torch.rand(20, 32, 512)
|
22 |
+
>>> out = transformer_decoder(tgt, memory)
|
23 |
+
"""
|
24 |
+
__constants__ = ['norm']
|
25 |
+
|
26 |
+
def __init__(self, decoder_layer, num_layers, norm=None):
|
27 |
+
super(TransformerDecoder, self).__init__()
|
28 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
29 |
+
self.num_layers = num_layers
|
30 |
+
self.norm = norm
|
31 |
+
|
32 |
+
def forward(self, tgt, memory, memory2=None, tgt_mask=None,
|
33 |
+
memory_mask=None, memory_mask2=None, tgt_key_padding_mask=None,
|
34 |
+
memory_key_padding_mask=None, memory_key_padding_mask2=None):
|
35 |
+
# type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
|
36 |
+
r"""Pass the inputs (and mask) through the decoder layer in turn.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
tgt: the sequence to the decoder (required).
|
40 |
+
memory: the sequence from the last layer of the encoder (required).
|
41 |
+
tgt_mask: the mask for the tgt sequence (optional).
|
42 |
+
memory_mask: the mask for the memory sequence (optional).
|
43 |
+
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
44 |
+
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
45 |
+
|
46 |
+
Shape:
|
47 |
+
see the docs in Transformer class.
|
48 |
+
"""
|
49 |
+
output = tgt
|
50 |
+
|
51 |
+
for mod in self.layers:
|
52 |
+
output = mod(output, memory, memory2=memory2, tgt_mask=tgt_mask,
|
53 |
+
memory_mask=memory_mask, memory_mask2=memory_mask2,
|
54 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
55 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
56 |
+
memory_key_padding_mask2=memory_key_padding_mask2)
|
57 |
+
|
58 |
+
if self.norm is not None:
|
59 |
+
output = self.norm(output)
|
60 |
+
|
61 |
+
return output
|
62 |
+
|
63 |
+
|
64 |
+
class TransformerDecoderLayer(nn.Module):
|
65 |
+
r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
|
66 |
+
This standard decoder layer is based on the paper "Attention Is All You Need".
|
67 |
+
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
68 |
+
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
69 |
+
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
70 |
+
in a different way during application.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
d_model: the number of expected features in the input (required).
|
74 |
+
nhead: the number of heads in the multiheadattention models (required).
|
75 |
+
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
76 |
+
dropout: the dropout value (default=0.1).
|
77 |
+
activation: the activation function of intermediate layer, relu or gelu (default=relu).
|
78 |
+
|
79 |
+
Examples::
|
80 |
+
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
|
81 |
+
>>> memory = torch.rand(10, 32, 512)
|
82 |
+
>>> tgt = torch.rand(20, 32, 512)
|
83 |
+
>>> out = decoder_layer(tgt, memory)
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
87 |
+
activation="relu", self_attn=True, siamese=False, debug=False):
|
88 |
+
super().__init__()
|
89 |
+
self.has_self_attn, self.siamese = self_attn, siamese
|
90 |
+
self.debug = debug
|
91 |
+
if self.has_self_attn:
|
92 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
93 |
+
self.norm1 = nn.LayerNorm(d_model)
|
94 |
+
self.dropout1 = nn.Dropout(dropout)
|
95 |
+
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
96 |
+
# Implementation of Feedforward model
|
97 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
98 |
+
self.dropout = nn.Dropout(dropout)
|
99 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
100 |
+
|
101 |
+
self.norm2 = nn.LayerNorm(d_model)
|
102 |
+
self.norm3 = nn.LayerNorm(d_model)
|
103 |
+
self.dropout2 = nn.Dropout(dropout)
|
104 |
+
self.dropout3 = nn.Dropout(dropout)
|
105 |
+
if self.siamese:
|
106 |
+
self.multihead_attn2 = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
107 |
+
|
108 |
+
self.activation = _get_activation_fn(activation)
|
109 |
+
|
110 |
+
def __setstate__(self, state):
|
111 |
+
if 'activation' not in state:
|
112 |
+
state['activation'] = F.relu
|
113 |
+
super().__setstate__(state)
|
114 |
+
|
115 |
+
def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
|
116 |
+
tgt_key_padding_mask=None, memory_key_padding_mask=None,
|
117 |
+
memory2=None, memory_mask2=None, memory_key_padding_mask2=None):
|
118 |
+
# type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
|
119 |
+
r"""Pass the inputs (and mask) through the decoder layer.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
tgt: the sequence to the decoder layer (required).
|
123 |
+
memory: the sequence from the last layer of the encoder (required).
|
124 |
+
tgt_mask: the mask for the tgt sequence (optional).
|
125 |
+
memory_mask: the mask for the memory sequence (optional).
|
126 |
+
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
127 |
+
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
128 |
+
|
129 |
+
Shape:
|
130 |
+
see the docs in Transformer class.
|
131 |
+
"""
|
132 |
+
if self.has_self_attn:
|
133 |
+
tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
|
134 |
+
key_padding_mask=tgt_key_padding_mask)
|
135 |
+
tgt = tgt + self.dropout1(tgt2)
|
136 |
+
tgt = self.norm1(tgt)
|
137 |
+
if self.debug: self.attn = attn
|
138 |
+
tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
|
139 |
+
key_padding_mask=memory_key_padding_mask)
|
140 |
+
if self.debug: self.attn2 = attn2
|
141 |
+
|
142 |
+
if self.siamese:
|
143 |
+
tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2,
|
144 |
+
key_padding_mask=memory_key_padding_mask2)
|
145 |
+
tgt = tgt + self.dropout2(tgt3)
|
146 |
+
if self.debug: self.attn3 = attn3
|
147 |
+
|
148 |
+
tgt = tgt + self.dropout2(tgt2)
|
149 |
+
tgt = self.norm2(tgt)
|
150 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
151 |
+
tgt = tgt + self.dropout3(tgt2)
|
152 |
+
tgt = self.norm3(tgt)
|
153 |
+
|
154 |
+
return tgt
|
155 |
+
|
156 |
+
|
157 |
+
class PositionalEncoding(nn.Module):
|
158 |
+
r"""Inject some information about the relative or absolute position of the tokens
|
159 |
+
in the sequence. The positional encodings have the same dimension as
|
160 |
+
the embeddings, so that the two can be summed. Here, we use sine and cosine
|
161 |
+
functions of different frequencies.
|
162 |
+
.. math::
|
163 |
+
\text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
|
164 |
+
\text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
|
165 |
+
\text{where pos is the word position and i is the embed idx)
|
166 |
+
Args:
|
167 |
+
d_model: the embed dim (required).
|
168 |
+
dropout: the dropout value (default=0.1).
|
169 |
+
max_len: the max. length of the incoming sequence (default=5000).
|
170 |
+
Examples:
|
171 |
+
>>> pos_encoder = PositionalEncoding(d_model)
|
172 |
+
"""
|
173 |
+
|
174 |
+
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
175 |
+
super().__init__()
|
176 |
+
self.dropout = nn.Dropout(p=dropout)
|
177 |
+
|
178 |
+
pe = torch.zeros(max_len, d_model)
|
179 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
180 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
181 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
182 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
183 |
+
pe = pe.unsqueeze(0).transpose(0, 1)
|
184 |
+
self.register_buffer('pe', pe)
|
185 |
+
|
186 |
+
def forward(self, x):
|
187 |
+
r"""Inputs of forward function
|
188 |
+
Args:
|
189 |
+
x: the sequence fed to the positional encoder model (required).
|
190 |
+
Shape:
|
191 |
+
x: [sequence length, batch size, embed dim]
|
192 |
+
output: [sequence length, batch size, embed dim]
|
193 |
+
Examples:
|
194 |
+
>>> output = pos_encoder(x)
|
195 |
+
"""
|
196 |
+
|
197 |
+
x = x + self.pe[:x.size(0), :]
|
198 |
+
return self.dropout(x)
|
IndicPhotoOCR/utils/strhub/models/base.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scene Text Recognition Model Hub
|
2 |
+
# Copyright 2022 Darwin Bautista
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import math
|
17 |
+
from abc import ABC, abstractmethod
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import Optional
|
20 |
+
|
21 |
+
from nltk import edit_distance
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.nn.functional as F
|
25 |
+
from torch import Tensor
|
26 |
+
from torch.optim import Optimizer
|
27 |
+
from torch.optim.lr_scheduler import OneCycleLR
|
28 |
+
|
29 |
+
import pytorch_lightning as pl
|
30 |
+
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
31 |
+
from timm.optim import create_optimizer_v2
|
32 |
+
|
33 |
+
from IndicPhotoOCR.utils.strhub.data.utils import BaseTokenizer, CharsetAdapter, CTCTokenizer, Tokenizer
|
34 |
+
|
35 |
+
|
36 |
+
@dataclass
|
37 |
+
class BatchResult:
|
38 |
+
num_samples: int
|
39 |
+
correct: int
|
40 |
+
ned: float
|
41 |
+
confidence: float
|
42 |
+
label_length: int
|
43 |
+
loss: Tensor
|
44 |
+
loss_numel: int
|
45 |
+
|
46 |
+
|
47 |
+
EPOCH_OUTPUT = list[dict[str, BatchResult]]
|
48 |
+
|
49 |
+
|
50 |
+
class BaseSystem(pl.LightningModule, ABC):
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
tokenizer: BaseTokenizer,
|
55 |
+
charset_test: str,
|
56 |
+
batch_size: int,
|
57 |
+
lr: float,
|
58 |
+
warmup_pct: float,
|
59 |
+
weight_decay: float,
|
60 |
+
) -> None:
|
61 |
+
super().__init__()
|
62 |
+
self.tokenizer = tokenizer
|
63 |
+
self.charset_adapter = CharsetAdapter(charset_test)
|
64 |
+
self.batch_size = batch_size
|
65 |
+
self.lr = lr
|
66 |
+
self.warmup_pct = warmup_pct
|
67 |
+
self.weight_decay = weight_decay
|
68 |
+
self.outputs: EPOCH_OUTPUT = []
|
69 |
+
|
70 |
+
@abstractmethod
|
71 |
+
def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
|
72 |
+
"""Inference
|
73 |
+
|
74 |
+
Args:
|
75 |
+
images: Batch of images. Shape: N, Ch, H, W
|
76 |
+
max_length: Max sequence length of the output. If None, will use default.
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
|
80 |
+
"""
|
81 |
+
raise NotImplementedError
|
82 |
+
|
83 |
+
@abstractmethod
|
84 |
+
def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
|
85 |
+
"""Like forward(), but also computes the loss (calls forward() internally).
|
86 |
+
|
87 |
+
Args:
|
88 |
+
images: Batch of images. Shape: N, Ch, H, W
|
89 |
+
labels: Text labels of the images
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
|
93 |
+
loss: mean loss for the batch
|
94 |
+
loss_numel: number of elements the loss was calculated from
|
95 |
+
"""
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
def configure_optimizers(self):
|
99 |
+
agb = self.trainer.accumulate_grad_batches
|
100 |
+
# Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP.
|
101 |
+
lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.0
|
102 |
+
lr = lr_scale * self.lr
|
103 |
+
optim = create_optimizer_v2(self, 'adamw', lr, self.weight_decay)
|
104 |
+
sched = OneCycleLR(
|
105 |
+
optim, lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, cycle_momentum=False
|
106 |
+
)
|
107 |
+
return {'optimizer': optim, 'lr_scheduler': {'scheduler': sched, 'interval': 'step'}}
|
108 |
+
|
109 |
+
def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer) -> None:
|
110 |
+
optimizer.zero_grad(set_to_none=True)
|
111 |
+
|
112 |
+
def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]:
|
113 |
+
images, labels = batch
|
114 |
+
|
115 |
+
correct = 0
|
116 |
+
total = 0
|
117 |
+
ned = 0
|
118 |
+
confidence = 0
|
119 |
+
label_length = 0
|
120 |
+
if validation:
|
121 |
+
logits, loss, loss_numel = self.forward_logits_loss(images, labels)
|
122 |
+
else:
|
123 |
+
# At test-time, we shouldn't specify a max_label_length because the test-time charset used
|
124 |
+
# might be different from the train-time charset. max_label_length in eval_logits_loss() is computed
|
125 |
+
# based on the transformed label, which could be wrong if the actual gt label contains characters existing
|
126 |
+
# in the train-time charset but not in the test-time charset. For example, "aishahaleyes.blogspot.com"
|
127 |
+
# is exactly 25 characters, but if processed by CharsetAdapter for the 36-char set, it becomes 23 characters
|
128 |
+
# long only, which sets max_label_length = 23. This will cause the model prediction to be truncated.
|
129 |
+
logits = self.forward(images)
|
130 |
+
loss = loss_numel = None # Only used for validation; not needed at test-time.
|
131 |
+
|
132 |
+
probs = logits.softmax(-1)
|
133 |
+
preds, probs = self.tokenizer.decode(probs)
|
134 |
+
for pred, prob, gt in zip(preds, probs, labels):
|
135 |
+
confidence += prob.prod().item()
|
136 |
+
pred = self.charset_adapter(pred)
|
137 |
+
# Follow ICDAR 2019 definition of N.E.D.
|
138 |
+
ned += edit_distance(pred, gt) / max(len(pred), len(gt))
|
139 |
+
if pred == gt:
|
140 |
+
correct += 1
|
141 |
+
total += 1
|
142 |
+
label_length += len(pred)
|
143 |
+
return dict(output=BatchResult(total, correct, ned, confidence, label_length, loss, loss_numel))
|
144 |
+
|
145 |
+
@staticmethod
|
146 |
+
def _aggregate_results(outputs: EPOCH_OUTPUT) -> tuple[float, float, float]:
|
147 |
+
if not outputs:
|
148 |
+
return 0.0, 0.0, 0.0
|
149 |
+
total_loss = 0
|
150 |
+
total_loss_numel = 0
|
151 |
+
total_n_correct = 0
|
152 |
+
total_norm_ED = 0
|
153 |
+
total_size = 0
|
154 |
+
for result in outputs:
|
155 |
+
result = result['output']
|
156 |
+
total_loss += result.loss_numel * result.loss
|
157 |
+
total_loss_numel += result.loss_numel
|
158 |
+
total_n_correct += result.correct
|
159 |
+
total_norm_ED += result.ned
|
160 |
+
total_size += result.num_samples
|
161 |
+
acc = total_n_correct / total_size
|
162 |
+
ned = 1 - total_norm_ED / total_size
|
163 |
+
loss = total_loss / total_loss_numel
|
164 |
+
return acc, ned, loss
|
165 |
+
|
166 |
+
def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
|
167 |
+
result = self._eval_step(batch, True)
|
168 |
+
self.outputs.append(result)
|
169 |
+
return result
|
170 |
+
|
171 |
+
def on_validation_epoch_end(self) -> None:
|
172 |
+
acc, ned, loss = self._aggregate_results(self.outputs)
|
173 |
+
self.outputs.clear()
|
174 |
+
self.log('val_accuracy', 100 * acc, sync_dist=True)
|
175 |
+
self.log('val_NED', 100 * ned, sync_dist=True)
|
176 |
+
self.log('val_loss', loss, sync_dist=True)
|
177 |
+
self.log('hp_metric', acc, sync_dist=True)
|
178 |
+
|
179 |
+
def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
|
180 |
+
return self._eval_step(batch, False)
|
181 |
+
|
182 |
+
|
183 |
+
class CrossEntropySystem(BaseSystem):
|
184 |
+
|
185 |
+
def __init__(
|
186 |
+
self, charset_train: str, charset_test: str, batch_size: int, lr: float, warmup_pct: float, weight_decay: float
|
187 |
+
) -> None:
|
188 |
+
tokenizer = Tokenizer(charset_train)
|
189 |
+
super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
|
190 |
+
self.bos_id = tokenizer.bos_id
|
191 |
+
self.eos_id = tokenizer.eos_id
|
192 |
+
self.pad_id = tokenizer.pad_id
|
193 |
+
|
194 |
+
def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
|
195 |
+
targets = self.tokenizer.encode(labels, self.device)
|
196 |
+
targets = targets[:, 1:] # Discard <bos>
|
197 |
+
max_len = targets.shape[1] - 1 # exclude <eos> from count
|
198 |
+
logits = self.forward(images, max_len)
|
199 |
+
loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id)
|
200 |
+
loss_numel = (targets != self.pad_id).sum()
|
201 |
+
return logits, loss, loss_numel
|
202 |
+
|
203 |
+
|
204 |
+
class CTCSystem(BaseSystem):
|
205 |
+
|
206 |
+
def __init__(
|
207 |
+
self, charset_train: str, charset_test: str, batch_size: int, lr: float, warmup_pct: float, weight_decay: float
|
208 |
+
) -> None:
|
209 |
+
tokenizer = CTCTokenizer(charset_train)
|
210 |
+
super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
|
211 |
+
self.blank_id = tokenizer.blank_id
|
212 |
+
|
213 |
+
def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
|
214 |
+
targets = self.tokenizer.encode(labels, self.device)
|
215 |
+
logits = self.forward(images)
|
216 |
+
log_probs = logits.log_softmax(-1).transpose(0, 1) # swap batch and seq. dims
|
217 |
+
T, N, _ = log_probs.shape
|
218 |
+
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long, device=self.device)
|
219 |
+
target_lengths = torch.as_tensor(list(map(len, labels)), dtype=torch.long, device=self.device)
|
220 |
+
loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True)
|
221 |
+
return logits, loss, N
|
IndicPhotoOCR/utils/strhub/models/crnn/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The MIT License (MIT)
|
2 |
+
|
3 |
+
Copyright (c) 2017 Jieru Mei <meijieru@gmail.com>
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
IndicPhotoOCR/utils/strhub/models/crnn/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""
|
2 |
+
Shi, Baoguang, Xiang Bai, and Cong Yao.
|
3 |
+
"An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition."
|
4 |
+
IEEE transactions on pattern analysis and machine intelligence 39, no. 11 (2016): 2298-2304.
|
5 |
+
|
6 |
+
https://arxiv.org/abs/1507.05717
|
7 |
+
|
8 |
+
All source files, except `system.py`, are based on the implementation listed below,
|
9 |
+
and hence are released under the license of the original.
|
10 |
+
|
11 |
+
Source: https://github.com/meijieru/crnn.pytorch
|
12 |
+
License: MIT License (see included LICENSE file)
|
13 |
+
"""
|
IndicPhotoOCR/utils/strhub/models/crnn/model.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
from strhub.models.modules import BidirectionalLSTM
|
4 |
+
|
5 |
+
|
6 |
+
class CRNN(nn.Module):
|
7 |
+
|
8 |
+
def __init__(self, img_h, nc, nclass, nh, leaky_relu=False):
|
9 |
+
super().__init__()
|
10 |
+
assert img_h % 16 == 0, 'img_h has to be a multiple of 16'
|
11 |
+
|
12 |
+
ks = [3, 3, 3, 3, 3, 3, 2]
|
13 |
+
ps = [1, 1, 1, 1, 1, 1, 0]
|
14 |
+
ss = [1, 1, 1, 1, 1, 1, 1]
|
15 |
+
nm = [64, 128, 256, 256, 512, 512, 512]
|
16 |
+
|
17 |
+
cnn = nn.Sequential()
|
18 |
+
|
19 |
+
def convRelu(i, batchNormalization=False):
|
20 |
+
nIn = nc if i == 0 else nm[i - 1]
|
21 |
+
nOut = nm[i]
|
22 |
+
cnn.add_module(f'conv{i}',
|
23 |
+
nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i], bias=not batchNormalization))
|
24 |
+
if batchNormalization:
|
25 |
+
cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(nOut))
|
26 |
+
if leaky_relu:
|
27 |
+
cnn.add_module(f'relu{i}',
|
28 |
+
nn.LeakyReLU(0.2, inplace=True))
|
29 |
+
else:
|
30 |
+
cnn.add_module(f'relu{i}', nn.ReLU(True))
|
31 |
+
|
32 |
+
convRelu(0)
|
33 |
+
cnn.add_module('pooling0', nn.MaxPool2d(2, 2)) # 64x16x64
|
34 |
+
convRelu(1)
|
35 |
+
cnn.add_module('pooling1', nn.MaxPool2d(2, 2)) # 128x8x32
|
36 |
+
convRelu(2, True)
|
37 |
+
convRelu(3)
|
38 |
+
cnn.add_module('pooling2',
|
39 |
+
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
|
40 |
+
convRelu(4, True)
|
41 |
+
convRelu(5)
|
42 |
+
cnn.add_module('pooling3',
|
43 |
+
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
|
44 |
+
convRelu(6, True) # 512x1x16
|
45 |
+
|
46 |
+
self.cnn = cnn
|
47 |
+
self.rnn = nn.Sequential(
|
48 |
+
BidirectionalLSTM(512, nh, nh),
|
49 |
+
BidirectionalLSTM(nh, nh, nclass))
|
50 |
+
|
51 |
+
def forward(self, input):
|
52 |
+
# conv features
|
53 |
+
conv = self.cnn(input)
|
54 |
+
b, c, h, w = conv.size()
|
55 |
+
assert h == 1, 'the height of conv must be 1'
|
56 |
+
conv = conv.squeeze(2)
|
57 |
+
conv = conv.transpose(1, 2) # [b, w, c]
|
58 |
+
|
59 |
+
# rnn features
|
60 |
+
output = self.rnn(conv)
|
61 |
+
|
62 |
+
return output
|
IndicPhotoOCR/utils/strhub/models/crnn/system.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scene Text Recognition Model Hub
|
2 |
+
# Copyright 2022 Darwin Bautista
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Optional, Sequence
|
17 |
+
|
18 |
+
from torch import Tensor
|
19 |
+
|
20 |
+
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
21 |
+
|
22 |
+
from strhub.models.base import CTCSystem
|
23 |
+
from strhub.models.utils import init_weights
|
24 |
+
|
25 |
+
from .model import CRNN as Model
|
26 |
+
|
27 |
+
|
28 |
+
class CRNN(CTCSystem):
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
charset_train: str,
|
33 |
+
charset_test: str,
|
34 |
+
max_label_length: int,
|
35 |
+
batch_size: int,
|
36 |
+
lr: float,
|
37 |
+
warmup_pct: float,
|
38 |
+
weight_decay: float,
|
39 |
+
img_size: Sequence[int],
|
40 |
+
hidden_size: int,
|
41 |
+
leaky_relu: bool,
|
42 |
+
**kwargs,
|
43 |
+
) -> None:
|
44 |
+
super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
|
45 |
+
self.save_hyperparameters()
|
46 |
+
self.model = Model(img_size[0], 3, len(self.tokenizer), hidden_size, leaky_relu)
|
47 |
+
self.model.apply(init_weights)
|
48 |
+
|
49 |
+
def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
|
50 |
+
return self.model.forward(images)
|
51 |
+
|
52 |
+
def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
|
53 |
+
images, labels = batch
|
54 |
+
loss = self.forward_logits_loss(images, labels)[1]
|
55 |
+
self.log('loss', loss)
|
56 |
+
return loss
|
IndicPhotoOCR/utils/strhub/models/modules.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""Shared modules used by CRNN and TRBA"""
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
|
5 |
+
class BidirectionalLSTM(nn.Module):
|
6 |
+
"""Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py"""
|
7 |
+
|
8 |
+
def __init__(self, input_size, hidden_size, output_size):
|
9 |
+
super().__init__()
|
10 |
+
self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
|
11 |
+
self.linear = nn.Linear(hidden_size * 2, output_size)
|
12 |
+
|
13 |
+
def forward(self, input):
|
14 |
+
"""
|
15 |
+
input : visual feature [batch_size x T x input_size], T = num_steps.
|
16 |
+
output : contextual feature [batch_size x T x output_size]
|
17 |
+
"""
|
18 |
+
recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
|
19 |
+
output = self.linear(recurrent) # batch_size x T x output_size
|
20 |
+
return output
|
IndicPhotoOCR/utils/strhub/models/parseq/__init__.py
ADDED
File without changes
|
IndicPhotoOCR/utils/strhub/models/parseq/model.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scene Text Recognition Model Hub
|
2 |
+
# Copyright 2022 Darwin Bautista
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from functools import partial
|
17 |
+
from typing import Optional, Sequence
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
from torch import Tensor
|
22 |
+
|
23 |
+
from timm.models.helpers import named_apply
|
24 |
+
|
25 |
+
from IndicPhotoOCR.utils.strhub.data.utils import Tokenizer
|
26 |
+
from IndicPhotoOCR.utils.strhub.models.utils import init_weights
|
27 |
+
|
28 |
+
from .modules import Decoder, DecoderLayer, Encoder, TokenEmbedding
|
29 |
+
|
30 |
+
|
31 |
+
class PARSeq(nn.Module):
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
num_tokens: int,
|
36 |
+
max_label_length: int,
|
37 |
+
img_size: Sequence[int],
|
38 |
+
patch_size: Sequence[int],
|
39 |
+
embed_dim: int,
|
40 |
+
enc_num_heads: int,
|
41 |
+
enc_mlp_ratio: int,
|
42 |
+
enc_depth: int,
|
43 |
+
dec_num_heads: int,
|
44 |
+
dec_mlp_ratio: int,
|
45 |
+
dec_depth: int,
|
46 |
+
decode_ar: bool,
|
47 |
+
refine_iters: int,
|
48 |
+
dropout: float,
|
49 |
+
) -> None:
|
50 |
+
super().__init__()
|
51 |
+
|
52 |
+
self.max_label_length = max_label_length
|
53 |
+
self.decode_ar = decode_ar
|
54 |
+
self.refine_iters = refine_iters
|
55 |
+
|
56 |
+
self.encoder = Encoder(
|
57 |
+
img_size, patch_size, embed_dim=embed_dim, depth=enc_depth, num_heads=enc_num_heads, mlp_ratio=enc_mlp_ratio
|
58 |
+
)
|
59 |
+
decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
|
60 |
+
self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=nn.LayerNorm(embed_dim))
|
61 |
+
|
62 |
+
# We don't predict <bos> nor <pad>
|
63 |
+
self.head = nn.Linear(embed_dim, num_tokens - 2)
|
64 |
+
self.text_embed = TokenEmbedding(num_tokens, embed_dim)
|
65 |
+
|
66 |
+
# +1 for <eos>
|
67 |
+
self.pos_queries = nn.Parameter(torch.Tensor(1, max_label_length + 1, embed_dim))
|
68 |
+
self.dropout = nn.Dropout(p=dropout)
|
69 |
+
# Encoder has its own init.
|
70 |
+
named_apply(partial(init_weights, exclude=['encoder']), self)
|
71 |
+
nn.init.trunc_normal_(self.pos_queries, std=0.02)
|
72 |
+
|
73 |
+
@property
|
74 |
+
def _device(self) -> torch.device:
|
75 |
+
return next(self.head.parameters(recurse=False)).device
|
76 |
+
|
77 |
+
@torch.jit.ignore
|
78 |
+
def no_weight_decay(self):
|
79 |
+
param_names = {'text_embed.embedding.weight', 'pos_queries'}
|
80 |
+
enc_param_names = {'encoder.' + n for n in self.encoder.no_weight_decay()}
|
81 |
+
return param_names.union(enc_param_names)
|
82 |
+
|
83 |
+
def encode(self, img: torch.Tensor):
|
84 |
+
return self.encoder(img)
|
85 |
+
|
86 |
+
def decode(
|
87 |
+
self,
|
88 |
+
tgt: torch.Tensor,
|
89 |
+
memory: torch.Tensor,
|
90 |
+
tgt_mask: Optional[Tensor] = None,
|
91 |
+
tgt_padding_mask: Optional[Tensor] = None,
|
92 |
+
tgt_query: Optional[Tensor] = None,
|
93 |
+
tgt_query_mask: Optional[Tensor] = None,
|
94 |
+
):
|
95 |
+
N, L = tgt.shape
|
96 |
+
# <bos> stands for the null context. We only supply position information for characters after <bos>.
|
97 |
+
null_ctx = self.text_embed(tgt[:, :1])
|
98 |
+
tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
|
99 |
+
tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
|
100 |
+
if tgt_query is None:
|
101 |
+
tgt_query = self.pos_queries[:, :L].expand(N, -1, -1)
|
102 |
+
tgt_query = self.dropout(tgt_query)
|
103 |
+
return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)
|
104 |
+
|
105 |
+
def forward(self, tokenizer: Tokenizer, images: Tensor, max_length: Optional[int] = None) -> Tensor:
|
106 |
+
testing = max_length is None
|
107 |
+
max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
|
108 |
+
bs = images.shape[0]
|
109 |
+
# +1 for <eos> at end of sequence.
|
110 |
+
num_steps = max_length + 1
|
111 |
+
memory = self.encode(images)
|
112 |
+
|
113 |
+
# Query positions up to `num_steps`
|
114 |
+
pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)
|
115 |
+
|
116 |
+
# Special case for the forward permutation. Faster than using `generate_attn_masks()`
|
117 |
+
tgt_mask = query_mask = torch.triu(torch.ones((num_steps, num_steps), dtype=torch.bool, device=self._device), 1)
|
118 |
+
|
119 |
+
if self.decode_ar:
|
120 |
+
tgt_in = torch.full((bs, num_steps), tokenizer.pad_id, dtype=torch.long, device=self._device)
|
121 |
+
tgt_in[:, 0] = tokenizer.bos_id
|
122 |
+
|
123 |
+
logits = []
|
124 |
+
for i in range(num_steps):
|
125 |
+
j = i + 1 # next token index
|
126 |
+
# Efficient decoding:
|
127 |
+
# Input the context up to the ith token. We use only one query (at position = i) at a time.
|
128 |
+
# This works because of the lookahead masking effect of the canonical (forward) AR context.
|
129 |
+
# Past tokens have no access to future tokens, hence are fixed once computed.
|
130 |
+
tgt_out = self.decode(
|
131 |
+
tgt_in[:, :j],
|
132 |
+
memory,
|
133 |
+
tgt_mask[:j, :j],
|
134 |
+
tgt_query=pos_queries[:, i:j],
|
135 |
+
tgt_query_mask=query_mask[i:j, :j],
|
136 |
+
)
|
137 |
+
# the next token probability is in the output's ith token position
|
138 |
+
p_i = self.head(tgt_out)
|
139 |
+
logits.append(p_i)
|
140 |
+
if j < num_steps:
|
141 |
+
# greedy decode. add the next token index to the target input
|
142 |
+
tgt_in[:, j] = p_i.squeeze().argmax(-1)
|
143 |
+
# Efficient batch decoding: If all output words have at least one EOS token, end decoding.
|
144 |
+
if testing and (tgt_in == tokenizer.eos_id).any(dim=-1).all():
|
145 |
+
break
|
146 |
+
|
147 |
+
logits = torch.cat(logits, dim=1)
|
148 |
+
else:
|
149 |
+
# No prior context, so input is just <bos>. We query all positions.
|
150 |
+
tgt_in = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
|
151 |
+
tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
|
152 |
+
logits = self.head(tgt_out)
|
153 |
+
|
154 |
+
if self.refine_iters:
|
155 |
+
# For iterative refinement, we always use a 'cloze' mask.
|
156 |
+
# We can derive it from the AR forward mask by unmasking the token context to the right.
|
157 |
+
query_mask[torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool, device=self._device), 2)] = 0
|
158 |
+
bos = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
|
159 |
+
for i in range(self.refine_iters):
|
160 |
+
# Prior context is the previous output.
|
161 |
+
tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
|
162 |
+
# Mask tokens beyond the first EOS token.
|
163 |
+
tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(-1) > 0
|
164 |
+
tgt_out = self.decode(
|
165 |
+
tgt_in, memory, tgt_mask, tgt_padding_mask, pos_queries, query_mask[:, : tgt_in.shape[1]]
|
166 |
+
)
|
167 |
+
logits = self.head(tgt_out)
|
168 |
+
|
169 |
+
return logits
|
IndicPhotoOCR/utils/strhub/models/parseq/modules.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scene Text Recognition Model Hub
|
2 |
+
# Copyright 2022 Darwin Bautista
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import math
|
17 |
+
from typing import Optional
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import Tensor, nn as nn
|
21 |
+
from torch.nn import functional as F
|
22 |
+
from torch.nn.modules import transformer
|
23 |
+
|
24 |
+
from timm.models.vision_transformer import PatchEmbed, VisionTransformer
|
25 |
+
|
26 |
+
|
27 |
+
class DecoderLayer(nn.Module):
|
28 |
+
"""A Transformer decoder layer supporting two-stream attention (XLNet)
|
29 |
+
This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
|
30 |
+
|
31 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', layer_norm_eps=1e-5):
|
32 |
+
super().__init__()
|
33 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
|
34 |
+
self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
|
35 |
+
# Implementation of Feedforward model
|
36 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
37 |
+
self.dropout = nn.Dropout(dropout)
|
38 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
39 |
+
|
40 |
+
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
41 |
+
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
42 |
+
self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
43 |
+
self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
44 |
+
self.dropout1 = nn.Dropout(dropout)
|
45 |
+
self.dropout2 = nn.Dropout(dropout)
|
46 |
+
self.dropout3 = nn.Dropout(dropout)
|
47 |
+
|
48 |
+
self.activation = transformer._get_activation_fn(activation)
|
49 |
+
|
50 |
+
def __setstate__(self, state):
|
51 |
+
if 'activation' not in state:
|
52 |
+
state['activation'] = F.gelu
|
53 |
+
super().__setstate__(state)
|
54 |
+
|
55 |
+
def forward_stream(
|
56 |
+
self,
|
57 |
+
tgt: Tensor,
|
58 |
+
tgt_norm: Tensor,
|
59 |
+
tgt_kv: Tensor,
|
60 |
+
memory: Tensor,
|
61 |
+
tgt_mask: Optional[Tensor],
|
62 |
+
tgt_key_padding_mask: Optional[Tensor],
|
63 |
+
):
|
64 |
+
"""Forward pass for a single stream (i.e. content or query)
|
65 |
+
tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
|
66 |
+
Both tgt_kv and memory are expected to be LayerNorm'd too.
|
67 |
+
memory is LayerNorm'd by ViT.
|
68 |
+
"""
|
69 |
+
tgt2, sa_weights = self.self_attn(
|
70 |
+
tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
71 |
+
)
|
72 |
+
tgt = tgt + self.dropout1(tgt2)
|
73 |
+
|
74 |
+
tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
|
75 |
+
tgt = tgt + self.dropout2(tgt2)
|
76 |
+
|
77 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
|
78 |
+
tgt = tgt + self.dropout3(tgt2)
|
79 |
+
return tgt, sa_weights, ca_weights
|
80 |
+
|
81 |
+
def forward(
|
82 |
+
self,
|
83 |
+
query,
|
84 |
+
content,
|
85 |
+
memory,
|
86 |
+
query_mask: Optional[Tensor] = None,
|
87 |
+
content_mask: Optional[Tensor] = None,
|
88 |
+
content_key_padding_mask: Optional[Tensor] = None,
|
89 |
+
update_content: bool = True,
|
90 |
+
):
|
91 |
+
query_norm = self.norm_q(query)
|
92 |
+
content_norm = self.norm_c(content)
|
93 |
+
query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0]
|
94 |
+
if update_content:
|
95 |
+
content = self.forward_stream(
|
96 |
+
content, content_norm, content_norm, memory, content_mask, content_key_padding_mask
|
97 |
+
)[0]
|
98 |
+
return query, content
|
99 |
+
|
100 |
+
|
101 |
+
class Decoder(nn.Module):
|
102 |
+
__constants__ = ['norm']
|
103 |
+
|
104 |
+
def __init__(self, decoder_layer, num_layers, norm):
|
105 |
+
super().__init__()
|
106 |
+
self.layers = transformer._get_clones(decoder_layer, num_layers)
|
107 |
+
self.num_layers = num_layers
|
108 |
+
self.norm = norm
|
109 |
+
|
110 |
+
def forward(
|
111 |
+
self,
|
112 |
+
query,
|
113 |
+
content,
|
114 |
+
memory,
|
115 |
+
query_mask: Optional[Tensor] = None,
|
116 |
+
content_mask: Optional[Tensor] = None,
|
117 |
+
content_key_padding_mask: Optional[Tensor] = None,
|
118 |
+
):
|
119 |
+
for i, mod in enumerate(self.layers):
|
120 |
+
last = i == len(self.layers) - 1
|
121 |
+
query, content = mod(
|
122 |
+
query, content, memory, query_mask, content_mask, content_key_padding_mask, update_content=not last
|
123 |
+
)
|
124 |
+
query = self.norm(query)
|
125 |
+
return query
|
126 |
+
|
127 |
+
|
128 |
+
class Encoder(VisionTransformer):
|
129 |
+
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
img_size=224,
|
133 |
+
patch_size=16,
|
134 |
+
in_chans=3,
|
135 |
+
embed_dim=768,
|
136 |
+
depth=12,
|
137 |
+
num_heads=12,
|
138 |
+
mlp_ratio=4.0,
|
139 |
+
qkv_bias=True,
|
140 |
+
drop_rate=0.0,
|
141 |
+
attn_drop_rate=0.0,
|
142 |
+
drop_path_rate=0.0,
|
143 |
+
embed_layer=PatchEmbed,
|
144 |
+
):
|
145 |
+
super().__init__(
|
146 |
+
img_size,
|
147 |
+
patch_size,
|
148 |
+
in_chans,
|
149 |
+
embed_dim=embed_dim,
|
150 |
+
depth=depth,
|
151 |
+
num_heads=num_heads,
|
152 |
+
mlp_ratio=mlp_ratio,
|
153 |
+
qkv_bias=qkv_bias,
|
154 |
+
drop_rate=drop_rate,
|
155 |
+
attn_drop_rate=attn_drop_rate,
|
156 |
+
drop_path_rate=drop_path_rate,
|
157 |
+
embed_layer=embed_layer,
|
158 |
+
num_classes=0, # These
|
159 |
+
global_pool='', # disable the
|
160 |
+
class_token=False, # classifier head.
|
161 |
+
)
|
162 |
+
|
163 |
+
def forward(self, x):
|
164 |
+
# Return all tokens
|
165 |
+
return self.forward_features(x)
|
166 |
+
|
167 |
+
|
168 |
+
class TokenEmbedding(nn.Module):
|
169 |
+
|
170 |
+
def __init__(self, charset_size: int, embed_dim: int):
|
171 |
+
super().__init__()
|
172 |
+
self.embedding = nn.Embedding(charset_size, embed_dim)
|
173 |
+
self.embed_dim = embed_dim
|
174 |
+
|
175 |
+
def forward(self, tokens: torch.Tensor):
|
176 |
+
return math.sqrt(self.embed_dim) * self.embedding(tokens)
|
IndicPhotoOCR/utils/strhub/models/parseq/system.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scene Text Recognition Model Hub
|
2 |
+
# Copyright 2022 Darwin Bautista
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import math
|
17 |
+
from itertools import permutations
|
18 |
+
from typing import Any, Optional, Sequence
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from torch import Tensor
|
25 |
+
|
26 |
+
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
27 |
+
|
28 |
+
from IndicPhotoOCR.utils.strhub.models.base import CrossEntropySystem
|
29 |
+
|
30 |
+
from .model import PARSeq as Model
|
31 |
+
|
32 |
+
|
33 |
+
class PARSeq(CrossEntropySystem):
|
34 |
+
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
charset_train: str,
|
38 |
+
charset_test: str,
|
39 |
+
max_label_length: int,
|
40 |
+
batch_size: int,
|
41 |
+
lr: float,
|
42 |
+
warmup_pct: float,
|
43 |
+
weight_decay: float,
|
44 |
+
img_size: Sequence[int],
|
45 |
+
patch_size: Sequence[int],
|
46 |
+
embed_dim: int,
|
47 |
+
enc_num_heads: int,
|
48 |
+
enc_mlp_ratio: int,
|
49 |
+
enc_depth: int,
|
50 |
+
dec_num_heads: int,
|
51 |
+
dec_mlp_ratio: int,
|
52 |
+
dec_depth: int,
|
53 |
+
perm_num: int,
|
54 |
+
perm_forward: bool,
|
55 |
+
perm_mirrored: bool,
|
56 |
+
decode_ar: bool,
|
57 |
+
refine_iters: int,
|
58 |
+
dropout: float,
|
59 |
+
**kwargs: Any,
|
60 |
+
) -> None:
|
61 |
+
super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
|
62 |
+
self.save_hyperparameters()
|
63 |
+
|
64 |
+
self.model = Model(
|
65 |
+
len(self.tokenizer),
|
66 |
+
max_label_length,
|
67 |
+
img_size,
|
68 |
+
patch_size,
|
69 |
+
embed_dim,
|
70 |
+
enc_num_heads,
|
71 |
+
enc_mlp_ratio,
|
72 |
+
enc_depth,
|
73 |
+
dec_num_heads,
|
74 |
+
dec_mlp_ratio,
|
75 |
+
dec_depth,
|
76 |
+
decode_ar,
|
77 |
+
refine_iters,
|
78 |
+
dropout,
|
79 |
+
)
|
80 |
+
|
81 |
+
# Perm/attn mask stuff
|
82 |
+
self.rng = np.random.default_rng()
|
83 |
+
self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
|
84 |
+
self.perm_forward = perm_forward
|
85 |
+
self.perm_mirrored = perm_mirrored
|
86 |
+
|
87 |
+
def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
|
88 |
+
return self.model.forward(self.tokenizer, images, max_length)
|
89 |
+
|
90 |
+
def gen_tgt_perms(self, tgt):
|
91 |
+
"""Generate shared permutations for the whole batch.
|
92 |
+
This works because the same attention mask can be used for the shorter sequences
|
93 |
+
because of the padding mask.
|
94 |
+
"""
|
95 |
+
# We don't permute the position of BOS, we permute EOS separately
|
96 |
+
max_num_chars = tgt.shape[1] - 2
|
97 |
+
# Special handling for 1-character sequences
|
98 |
+
if max_num_chars == 1:
|
99 |
+
return torch.arange(3, device=self._device).unsqueeze(0)
|
100 |
+
perms = [torch.arange(max_num_chars, device=self._device)] if self.perm_forward else []
|
101 |
+
# Additional permutations if needed
|
102 |
+
max_perms = math.factorial(max_num_chars)
|
103 |
+
if self.perm_mirrored:
|
104 |
+
max_perms //= 2
|
105 |
+
num_gen_perms = min(self.max_gen_perms, max_perms)
|
106 |
+
# For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions
|
107 |
+
# Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars.
|
108 |
+
if max_num_chars < 5:
|
109 |
+
# Pool of permutations to sample from. We only need the first half (if complementary option is selected)
|
110 |
+
# Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
|
111 |
+
if max_num_chars == 4 and self.perm_mirrored:
|
112 |
+
selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
|
113 |
+
else:
|
114 |
+
selector = list(range(max_perms))
|
115 |
+
perm_pool = torch.as_tensor(
|
116 |
+
list(permutations(range(max_num_chars), max_num_chars)),
|
117 |
+
device=self._device,
|
118 |
+
)[selector]
|
119 |
+
# If the forward permutation is always selected, no need to add it to the pool for sampling
|
120 |
+
if self.perm_forward:
|
121 |
+
perm_pool = perm_pool[1:]
|
122 |
+
perms = torch.stack(perms)
|
123 |
+
if len(perm_pool):
|
124 |
+
i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(perms), replace=False)
|
125 |
+
perms = torch.cat([perms, perm_pool[i]])
|
126 |
+
else:
|
127 |
+
perms.extend(
|
128 |
+
[torch.randperm(max_num_chars, device=self._device) for _ in range(num_gen_perms - len(perms))]
|
129 |
+
)
|
130 |
+
perms = torch.stack(perms)
|
131 |
+
if self.perm_mirrored:
|
132 |
+
# Add complementary pairs
|
133 |
+
comp = perms.flip(-1)
|
134 |
+
# Stack in such a way that the pairs are next to each other.
|
135 |
+
perms = torch.stack([perms, comp]).transpose(0, 1).reshape(-1, max_num_chars)
|
136 |
+
# NOTE:
|
137 |
+
# The only meaningful way of permuting the EOS position is by moving it one character position at a time.
|
138 |
+
# However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS
|
139 |
+
# positions will always be much less than the number of permutations (unless a low perm_num is set).
|
140 |
+
# Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly
|
141 |
+
# distribute it across the chosen number of permutations.
|
142 |
+
# Add position indices of BOS and EOS
|
143 |
+
bos_idx = perms.new_zeros((len(perms), 1))
|
144 |
+
eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1)
|
145 |
+
perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1)
|
146 |
+
# Special handling for the reverse direction. This does two things:
|
147 |
+
# 1. Reverse context for the characters
|
148 |
+
# 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode)
|
149 |
+
if len(perms) > 1:
|
150 |
+
perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=self._device)
|
151 |
+
return perms
|
152 |
+
|
153 |
+
def generate_attn_masks(self, perm):
|
154 |
+
"""Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens)
|
155 |
+
:param perm: the permutation sequence. i = 0 is always the BOS
|
156 |
+
:return: lookahead attention masks
|
157 |
+
"""
|
158 |
+
sz = perm.shape[0]
|
159 |
+
mask = torch.zeros((sz, sz), dtype=torch.bool, device=self._device)
|
160 |
+
for i in range(sz):
|
161 |
+
query_idx = perm[i]
|
162 |
+
masked_keys = perm[i + 1 :]
|
163 |
+
mask[query_idx, masked_keys] = True
|
164 |
+
content_mask = mask[:-1, :-1].clone()
|
165 |
+
mask[torch.eye(sz, dtype=torch.bool, device=self._device)] = True # mask "self"
|
166 |
+
query_mask = mask[1:, :-1]
|
167 |
+
return content_mask, query_mask
|
168 |
+
|
169 |
+
def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
|
170 |
+
images, labels = batch
|
171 |
+
tgt = self.tokenizer.encode(labels, self._device)
|
172 |
+
|
173 |
+
# Encode the source sequence (i.e. the image codes)
|
174 |
+
memory = self.model.encode(images)
|
175 |
+
|
176 |
+
# Prepare the target sequences (input and output)
|
177 |
+
tgt_perms = self.gen_tgt_perms(tgt)
|
178 |
+
tgt_in = tgt[:, :-1]
|
179 |
+
tgt_out = tgt[:, 1:]
|
180 |
+
# The [EOS] token is not depended upon by any other token in any permutation ordering
|
181 |
+
tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
|
182 |
+
|
183 |
+
loss = 0
|
184 |
+
loss_numel = 0
|
185 |
+
n = (tgt_out != self.pad_id).sum().item()
|
186 |
+
for i, perm in enumerate(tgt_perms):
|
187 |
+
tgt_mask, query_mask = self.generate_attn_masks(perm)
|
188 |
+
out = self.model.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask)
|
189 |
+
logits = self.model.head(out).flatten(end_dim=1)
|
190 |
+
loss += n * F.cross_entropy(logits, tgt_out.flatten(), ignore_index=self.pad_id)
|
191 |
+
loss_numel += n
|
192 |
+
# After the second iteration (i.e. done with canonical and reverse orderings),
|
193 |
+
# remove the [EOS] tokens for the succeeding perms
|
194 |
+
if i == 1:
|
195 |
+
tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id, tgt_out)
|
196 |
+
n = (tgt_out != self.pad_id).sum().item()
|
197 |
+
loss /= loss_numel
|
198 |
+
|
199 |
+
self.log('loss', loss)
|
200 |
+
return loss
|
IndicPhotoOCR/utils/strhub/models/trba/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""
|
2 |
+
Baek, Jeonghun, Geewook Kim, Junyeop Lee, Sungrae Park, Dongyoon Han, Sangdoo Yun, Seong Joon Oh, and Hwalsuk Lee.
|
3 |
+
"What is wrong with scene text recognition model comparisons? dataset and model analysis."
|
4 |
+
In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4715-4723. 2019.
|
5 |
+
|
6 |
+
https://arxiv.org/abs/1904.01906
|
7 |
+
|
8 |
+
All source files, except `system.py`, are based on the implementation listed below,
|
9 |
+
and hence are released under the license of the original.
|
10 |
+
|
11 |
+
Source: https://github.com/clovaai/deep-text-recognition-benchmark
|
12 |
+
License: Apache License 2.0 (see LICENSE file in project root)
|
13 |
+
"""
|
IndicPhotoOCR/utils/strhub/models/trba/feature_extraction.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
from torchvision.models.resnet import BasicBlock
|
4 |
+
|
5 |
+
|
6 |
+
class ResNet_FeatureExtractor(nn.Module):
|
7 |
+
""" FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """
|
8 |
+
|
9 |
+
def __init__(self, input_channel, output_channel=512):
|
10 |
+
super().__init__()
|
11 |
+
self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3])
|
12 |
+
|
13 |
+
def forward(self, input):
|
14 |
+
return self.ConvNet(input)
|
15 |
+
|
16 |
+
|
17 |
+
class ResNet(nn.Module):
|
18 |
+
|
19 |
+
def __init__(self, input_channel, output_channel, block, layers):
|
20 |
+
super().__init__()
|
21 |
+
|
22 |
+
self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
|
23 |
+
|
24 |
+
self.inplanes = int(output_channel / 8)
|
25 |
+
self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
|
26 |
+
kernel_size=3, stride=1, padding=1, bias=False)
|
27 |
+
self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
|
28 |
+
self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
|
29 |
+
kernel_size=3, stride=1, padding=1, bias=False)
|
30 |
+
self.bn0_2 = nn.BatchNorm2d(self.inplanes)
|
31 |
+
self.relu = nn.ReLU(inplace=True)
|
32 |
+
|
33 |
+
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
34 |
+
self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
|
35 |
+
self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
|
36 |
+
0], kernel_size=3, stride=1, padding=1, bias=False)
|
37 |
+
self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
|
38 |
+
|
39 |
+
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
40 |
+
self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
|
41 |
+
self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
|
42 |
+
1], kernel_size=3, stride=1, padding=1, bias=False)
|
43 |
+
self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
|
44 |
+
|
45 |
+
self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
|
46 |
+
self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
|
47 |
+
self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
|
48 |
+
2], kernel_size=3, stride=1, padding=1, bias=False)
|
49 |
+
self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
|
50 |
+
|
51 |
+
self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
|
52 |
+
self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
|
53 |
+
3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
|
54 |
+
self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
|
55 |
+
self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
|
56 |
+
3], kernel_size=2, stride=1, padding=0, bias=False)
|
57 |
+
self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
|
58 |
+
|
59 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
60 |
+
downsample = None
|
61 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
62 |
+
downsample = nn.Sequential(
|
63 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
64 |
+
kernel_size=1, stride=stride, bias=False),
|
65 |
+
nn.BatchNorm2d(planes * block.expansion),
|
66 |
+
)
|
67 |
+
|
68 |
+
layers = []
|
69 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
70 |
+
self.inplanes = planes * block.expansion
|
71 |
+
for i in range(1, blocks):
|
72 |
+
layers.append(block(self.inplanes, planes))
|
73 |
+
|
74 |
+
return nn.Sequential(*layers)
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
x = self.conv0_1(x)
|
78 |
+
x = self.bn0_1(x)
|
79 |
+
x = self.relu(x)
|
80 |
+
x = self.conv0_2(x)
|
81 |
+
x = self.bn0_2(x)
|
82 |
+
x = self.relu(x)
|
83 |
+
|
84 |
+
x = self.maxpool1(x)
|
85 |
+
x = self.layer1(x)
|
86 |
+
x = self.conv1(x)
|
87 |
+
x = self.bn1(x)
|
88 |
+
x = self.relu(x)
|
89 |
+
|
90 |
+
x = self.maxpool2(x)
|
91 |
+
x = self.layer2(x)
|
92 |
+
x = self.conv2(x)
|
93 |
+
x = self.bn2(x)
|
94 |
+
x = self.relu(x)
|
95 |
+
|
96 |
+
x = self.maxpool3(x)
|
97 |
+
x = self.layer3(x)
|
98 |
+
x = self.conv3(x)
|
99 |
+
x = self.bn3(x)
|
100 |
+
x = self.relu(x)
|
101 |
+
|
102 |
+
x = self.layer4(x)
|
103 |
+
x = self.conv4_1(x)
|
104 |
+
x = self.bn4_1(x)
|
105 |
+
x = self.relu(x)
|
106 |
+
x = self.conv4_2(x)
|
107 |
+
x = self.bn4_2(x)
|
108 |
+
x = self.relu(x)
|
109 |
+
|
110 |
+
return x
|
IndicPhotoOCR/utils/strhub/models/trba/model.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
from strhub.models.modules import BidirectionalLSTM
|
4 |
+
from .feature_extraction import ResNet_FeatureExtractor
|
5 |
+
from .prediction import Attention
|
6 |
+
from .transformation import TPS_SpatialTransformerNetwork
|
7 |
+
|
8 |
+
|
9 |
+
class TRBA(nn.Module):
|
10 |
+
|
11 |
+
def __init__(self, img_h, img_w, num_class, num_fiducial=20, input_channel=3, output_channel=512, hidden_size=256,
|
12 |
+
use_ctc=False):
|
13 |
+
super().__init__()
|
14 |
+
""" Transformation """
|
15 |
+
self.Transformation = TPS_SpatialTransformerNetwork(
|
16 |
+
F=num_fiducial, I_size=(img_h, img_w), I_r_size=(img_h, img_w),
|
17 |
+
I_channel_num=input_channel)
|
18 |
+
|
19 |
+
""" FeatureExtraction """
|
20 |
+
self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel)
|
21 |
+
self.FeatureExtraction_output = output_channel
|
22 |
+
self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1
|
23 |
+
|
24 |
+
""" Sequence modeling"""
|
25 |
+
self.SequenceModeling = nn.Sequential(
|
26 |
+
BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
|
27 |
+
BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
|
28 |
+
self.SequenceModeling_output = hidden_size
|
29 |
+
|
30 |
+
""" Prediction """
|
31 |
+
if use_ctc:
|
32 |
+
self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)
|
33 |
+
else:
|
34 |
+
self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class)
|
35 |
+
|
36 |
+
def forward(self, image, max_label_length, text=None):
|
37 |
+
""" Transformation stage """
|
38 |
+
image = self.Transformation(image)
|
39 |
+
|
40 |
+
""" Feature extraction stage """
|
41 |
+
visual_feature = self.FeatureExtraction(image)
|
42 |
+
visual_feature = visual_feature.permute(0, 3, 1, 2) # [b, c, h, w] -> [b, w, c, h]
|
43 |
+
visual_feature = self.AdaptiveAvgPool(visual_feature) # [b, w, c, h] -> [b, w, c, 1]
|
44 |
+
visual_feature = visual_feature.squeeze(3) # [b, w, c, 1] -> [b, w, c]
|
45 |
+
|
46 |
+
""" Sequence modeling stage """
|
47 |
+
contextual_feature = self.SequenceModeling(visual_feature) # [b, num_steps, hidden_size]
|
48 |
+
|
49 |
+
""" Prediction stage """
|
50 |
+
if isinstance(self.Prediction, Attention):
|
51 |
+
prediction = self.Prediction(contextual_feature.contiguous(), text, max_label_length)
|
52 |
+
else:
|
53 |
+
prediction = self.Prediction(contextual_feature.contiguous()) # CTC
|
54 |
+
|
55 |
+
return prediction # [b, num_steps, num_class]
|
IndicPhotoOCR/utils/strhub/models/trba/prediction.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class Attention(nn.Module):
|
7 |
+
|
8 |
+
def __init__(self, input_size, hidden_size, num_class, num_char_embeddings=256):
|
9 |
+
super().__init__()
|
10 |
+
self.attention_cell = AttentionCell(input_size, hidden_size, num_char_embeddings)
|
11 |
+
self.hidden_size = hidden_size
|
12 |
+
self.num_class = num_class
|
13 |
+
self.generator = nn.Linear(hidden_size, num_class)
|
14 |
+
self.char_embeddings = nn.Embedding(num_class, num_char_embeddings)
|
15 |
+
|
16 |
+
def forward(self, batch_H, text, max_label_length=25):
|
17 |
+
"""
|
18 |
+
input:
|
19 |
+
batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_class]
|
20 |
+
text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [SOS] token. text[:, 0] = [SOS].
|
21 |
+
output: probability distribution at each step [batch_size x num_steps x num_class]
|
22 |
+
"""
|
23 |
+
batch_size = batch_H.size(0)
|
24 |
+
num_steps = max_label_length + 1 # +1 for [EOS] at end of sentence.
|
25 |
+
|
26 |
+
output_hiddens = batch_H.new_zeros((batch_size, num_steps, self.hidden_size), dtype=torch.float)
|
27 |
+
hidden = (batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float),
|
28 |
+
batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float))
|
29 |
+
|
30 |
+
if self.training:
|
31 |
+
for i in range(num_steps):
|
32 |
+
char_embeddings = self.char_embeddings(text[:, i])
|
33 |
+
# hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_embeddings : f(y_{t-1})
|
34 |
+
hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings)
|
35 |
+
output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell)
|
36 |
+
probs = self.generator(output_hiddens)
|
37 |
+
|
38 |
+
else:
|
39 |
+
targets = text[0].expand(batch_size) # should be fill with [SOS] token
|
40 |
+
probs = batch_H.new_zeros((batch_size, num_steps, self.num_class), dtype=torch.float)
|
41 |
+
|
42 |
+
for i in range(num_steps):
|
43 |
+
char_embeddings = self.char_embeddings(targets)
|
44 |
+
hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings)
|
45 |
+
probs_step = self.generator(hidden[0])
|
46 |
+
probs[:, i, :] = probs_step
|
47 |
+
_, next_input = probs_step.max(1)
|
48 |
+
targets = next_input
|
49 |
+
|
50 |
+
return probs # batch_size x num_steps x num_class
|
51 |
+
|
52 |
+
|
53 |
+
class AttentionCell(nn.Module):
|
54 |
+
|
55 |
+
def __init__(self, input_size, hidden_size, num_embeddings):
|
56 |
+
super().__init__()
|
57 |
+
self.i2h = nn.Linear(input_size, hidden_size, bias=False)
|
58 |
+
self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias
|
59 |
+
self.score = nn.Linear(hidden_size, 1, bias=False)
|
60 |
+
self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
|
61 |
+
self.hidden_size = hidden_size
|
62 |
+
|
63 |
+
def forward(self, prev_hidden, batch_H, char_embeddings):
|
64 |
+
# [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
|
65 |
+
batch_H_proj = self.i2h(batch_H)
|
66 |
+
prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
|
67 |
+
e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1
|
68 |
+
|
69 |
+
alpha = F.softmax(e, dim=1)
|
70 |
+
context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel
|
71 |
+
concat_context = torch.cat([context, char_embeddings], 1) # batch_size x (num_channel + num_embedding)
|
72 |
+
cur_hidden = self.rnn(concat_context, prev_hidden)
|
73 |
+
return cur_hidden, alpha
|