anikde commited on
Commit
e2f99d5
·
1 Parent(s): 498450d

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +176 -0
  2. Dockerfile +45 -0
  3. IndicPhotoOCR/__init__.py +0 -0
  4. IndicPhotoOCR/detection/__init__.py +0 -0
  5. IndicPhotoOCR/detection/east_config.py +39 -0
  6. IndicPhotoOCR/detection/east_detector.py +87 -0
  7. IndicPhotoOCR/detection/east_locality_aware_nms.py +75 -0
  8. IndicPhotoOCR/detection/east_model.py +242 -0
  9. IndicPhotoOCR/detection/east_preprossing.py +681 -0
  10. IndicPhotoOCR/detection/east_utils.py +283 -0
  11. IndicPhotoOCR/ocr.py +154 -0
  12. IndicPhotoOCR/recognition/__init__.py +0 -0
  13. IndicPhotoOCR/recognition/parseq_recogniser.py +215 -0
  14. IndicPhotoOCR/script_identification/CLIP_identifier.py +201 -0
  15. IndicPhotoOCR/script_identification/__init__.py +0 -0
  16. IndicPhotoOCR/theme.py +43 -0
  17. IndicPhotoOCR/utils/strhub/__init__.py +2 -0
  18. IndicPhotoOCR/utils/strhub/data/__init__.py +1 -0
  19. IndicPhotoOCR/utils/strhub/data/aa_overrides.py +46 -0
  20. IndicPhotoOCR/utils/strhub/data/augment.py +112 -0
  21. IndicPhotoOCR/utils/strhub/data/dataset.py +148 -0
  22. IndicPhotoOCR/utils/strhub/data/module.py +157 -0
  23. IndicPhotoOCR/utils/strhub/data/utils.py +150 -0
  24. IndicPhotoOCR/utils/strhub/models/__init__.py +1 -0
  25. IndicPhotoOCR/utils/strhub/models/abinet/LICENSE +25 -0
  26. IndicPhotoOCR/utils/strhub/models/abinet/__init__.py +13 -0
  27. IndicPhotoOCR/utils/strhub/models/abinet/attention.py +100 -0
  28. IndicPhotoOCR/utils/strhub/models/abinet/backbone.py +24 -0
  29. IndicPhotoOCR/utils/strhub/models/abinet/model.py +31 -0
  30. IndicPhotoOCR/utils/strhub/models/abinet/model_abinet_iter.py +39 -0
  31. IndicPhotoOCR/utils/strhub/models/abinet/model_alignment.py +28 -0
  32. IndicPhotoOCR/utils/strhub/models/abinet/model_language.py +49 -0
  33. IndicPhotoOCR/utils/strhub/models/abinet/model_vision.py +45 -0
  34. IndicPhotoOCR/utils/strhub/models/abinet/resnet.py +72 -0
  35. IndicPhotoOCR/utils/strhub/models/abinet/system.py +215 -0
  36. IndicPhotoOCR/utils/strhub/models/abinet/transformer.py +198 -0
  37. IndicPhotoOCR/utils/strhub/models/base.py +221 -0
  38. IndicPhotoOCR/utils/strhub/models/crnn/LICENSE +21 -0
  39. IndicPhotoOCR/utils/strhub/models/crnn/__init__.py +13 -0
  40. IndicPhotoOCR/utils/strhub/models/crnn/model.py +62 -0
  41. IndicPhotoOCR/utils/strhub/models/crnn/system.py +56 -0
  42. IndicPhotoOCR/utils/strhub/models/modules.py +20 -0
  43. IndicPhotoOCR/utils/strhub/models/parseq/__init__.py +0 -0
  44. IndicPhotoOCR/utils/strhub/models/parseq/model.py +169 -0
  45. IndicPhotoOCR/utils/strhub/models/parseq/modules.py +176 -0
  46. IndicPhotoOCR/utils/strhub/models/parseq/system.py +200 -0
  47. IndicPhotoOCR/utils/strhub/models/trba/__init__.py +13 -0
  48. IndicPhotoOCR/utils/strhub/models/trba/feature_extraction.py +110 -0
  49. IndicPhotoOCR/utils/strhub/models/trba/model.py +55 -0
  50. IndicPhotoOCR/utils/strhub/models/trba/prediction.py +73 -0
.gitignore ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Output directories
2
+ outputs/
3
+ multirun/
4
+ ray_results/
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ # requirements/core.*.txt
16
+ .Python
17
+ build/
18
+ develop-eggs/
19
+ dist/
20
+ downloads/
21
+ eggs/
22
+ .eggs/
23
+ lib/
24
+ lib64/
25
+ parts/
26
+ sdist/
27
+ var/
28
+ wheels/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+ cover/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+ db.sqlite3
68
+ db.sqlite3-journal
69
+
70
+ # Flask stuff:
71
+ instance/
72
+ .webassets-cache
73
+
74
+ # Scrapy stuff:
75
+ .scrapy
76
+
77
+ # Sphinx documentation
78
+ docs/_build/
79
+
80
+ # PyBuilder
81
+ .pybuilder/
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # IPython
88
+ profile_default/
89
+ ipython_config.py
90
+
91
+ # pyenv
92
+ # For a library or package, you might want to ignore these files since the code is
93
+ # intended to run in multiple environments; otherwise, check them in:
94
+ # .python-version
95
+
96
+ # pipenv
97
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
99
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
100
+ # install all needed dependencies.
101
+ #Pipfile.lock
102
+
103
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104
+ __pypackages__/
105
+
106
+ # Celery stuff
107
+ celerybeat-schedule
108
+ celerybeat.pid
109
+
110
+ # SageMath parsed files
111
+ *.sage.py
112
+
113
+ # Environments
114
+ .env
115
+ .venv
116
+ env/
117
+ venv/
118
+ ENV/
119
+ env.bak/
120
+ venv.bak/
121
+ .python-version
122
+
123
+ # Spyder project settings
124
+ .spyderproject
125
+ .spyproject
126
+
127
+ # Rope project settings
128
+ .ropeproject
129
+
130
+ # mkdocs documentation
131
+ /site
132
+
133
+ # mypy
134
+ .mypy_cache/
135
+ .dmypy.json
136
+ dmypy.json
137
+
138
+ # Pyre type checker
139
+ .pyre/
140
+
141
+ # pytype static type analyzer
142
+ .pytype/
143
+
144
+ # Cython debug symbols
145
+ cython_debug/
146
+
147
+ # IDE
148
+ .idea/
149
+
150
+ ########## CUSTOM FOLDER ##############
151
+ README_original.md
152
+
153
+ results/
154
+ images
155
+ bharatSTR/East/tmp
156
+ bharatSTR/models
157
+ bharatSTR/images
158
+ __pycache__/
159
+ bharatSTR/
160
+
161
+ IndicPhotoOCR/detection/East
162
+ IndicPhotoOCR/recognition/models
163
+
164
+ IndicPhotoOCR/script_identification/images
165
+ IndicPhotoOCR/script_identification/models
166
+
167
+
168
+ build/
169
+ dist/
170
+ test.png
171
+ static/pics/IndicPhotoOCR.gif
172
+ input_image.jpg
173
+ output_image.png
174
+
175
+
176
+
Dockerfile ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use NVIDIA PyTorch as the base image
2
+ FROM nvcr.io/nvidia/pytorch:23.12-py3
3
+
4
+ # Install additional dependencies
5
+ RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6
6
+
7
+ # Set environment variables for Miniconda and Conda environment
8
+ ENV CONDA_DIR /opt/conda
9
+ ENV PATH $CONDA_DIR/bin:$PATH
10
+
11
+ # Install Miniconda
12
+ RUN apt-get update && apt-get install -y wget && \
13
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
14
+ bash Miniconda3-latest-Linux-x86_64.sh -b -p $CONDA_DIR && \
15
+ rm Miniconda3-latest-Linux-x86_64.sh
16
+
17
+ # Create a new Conda environment named "bocr" with Python 3.9
18
+ RUN conda create -n bocr python=3.9 -y
19
+
20
+ # Initialize conda
21
+ RUN conda init
22
+
23
+ # Reload the env configs
24
+ RUN source ~/.bashrc
25
+
26
+ # Make RUN commands use the bocr environment
27
+ SHELL ["conda", "run", "-n", "bocr", "/bin/bash", "-c"]
28
+
29
+ # # Set default shell to bash
30
+ # SHELL ["/bin/bash", "-c"]
31
+
32
+ # # Clone BharatOCR repository
33
+ # RUN git clone https://github.com/Bhashini-IITJ/BharatOCR.git && \
34
+ # git switch photoOCR && \
35
+ # cd IndicPhotoOCR && \
36
+ # python setup.py sdist bdist_wheel && \
37
+ # pip install ./dist/IndicPhotoOCR-1.1.0-py3-none-any.whl[cu118] --extra-index-url https://download.pytorch.org/whl/cu118
38
+
39
+ # # # Set default command to run BharatOCR
40
+ # CMD ["conda", "run", "-n", "bocr", "python", "-m", "IndicPhotoOCR.ocr"]
41
+
42
+
43
+ # cd IndicPhotoOCR
44
+ # sudo docker build -t indicphotoocr:latest .
45
+ # sudo docker run --gpus all --rm -it indicphotoocr:latest
IndicPhotoOCR/__init__.py ADDED
File without changes
IndicPhotoOCR/detection/__init__.py ADDED
File without changes
IndicPhotoOCR/detection/east_config.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # data-config
3
+ import numpy as np
4
+
5
+ train_data_path = './dataset/train/'
6
+ train_batch_size_per_gpu = 14 # 14
7
+ num_workers = 24 # 24
8
+ gpu_ids = [0] # [0,1,2,3]
9
+ gpu = 1 # 4
10
+ input_size = 512 # 预处理后归一化后图像尺寸
11
+ background_ratio = 3. / 8 # 纯背景样本比例
12
+ random_scale = np.array([0.5, 1, 2.0, 3.0]) # 提取多尺度图片信息
13
+ geometry = 'RBOX' # 选择使用几何特征图类型
14
+ max_image_large_side = 1280
15
+ max_text_size = 800
16
+ min_text_size = 10
17
+ min_crop_side_ratio = 0.1
18
+ means=[100, 100, 100]
19
+ pretrained = True # 是否加载基础网络的预训练模型
20
+ pretrained_basemodel_path = 'IndicPhotoOCR/detection/East/tmp/backbone_net/mobilenet_v2.pth.tar'
21
+ pre_lr = 1e-4 # 基础网络的初始学习率
22
+ lr = 1e-3 # 后面网络的初始学习率
23
+ decay_steps = 50 # decayed_learning_rate = learning_rate * decay_rate ^ (global_epoch / decay_steps)
24
+ decay_rate = 0.97
25
+ init_type = 'xavier' # 网络参数初始化方式
26
+ resume = True # 整体网络是否恢复原来保存的模型
27
+ checkpoint = 'IndicPhotoOCR/detection/East/tmp/epoch_990_checkpoint.pth.tar' # 指定具体路径及文件名
28
+ max_epochs = 1000 # 最大迭代epochs数
29
+ l2_weight_decay = 1e-6 # l2正则化惩罚项权重
30
+ print_freq = 10 # 每10个batch输出损失结果
31
+ save_eval_iteration = 50 # 每10个epoch保存一次模型,并做一次评价
32
+ save_model_path = './tmp/' # 模型保存路径
33
+ test_img_path = './dataset/full_set' # demo测试样本路径'./demo/test_img/',数据集测试为'./dataset/test/'
34
+ res_img_path = 'results' # demo结果存放路径'./demo/result_img/',数据集测试为 './dataset/test_result/'
35
+ write_images = True # 是否输出图像结果
36
+ score_map_thresh = 0.8 # 置信度阈值
37
+ box_thresh = 0.1 # 文本框中置信度平均值的阈值
38
+ nms_thres = 0.2 # 局部非极大抑制IOU阈值
39
+ compute_hmean_path = './dataset/test_compute_hmean/'
IndicPhotoOCR/detection/east_detector.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ import time
6
+ import warnings
7
+
8
+
9
+ import IndicPhotoOCR.detection.east_config as cfg
10
+ from IndicPhotoOCR.detection.east_utils import ModelManager
11
+ from IndicPhotoOCR.detection.east_model import East
12
+ import IndicPhotoOCR.detection.east_utils as utils
13
+
14
+ # Suppress warnings
15
+ warnings.filterwarnings("ignore")
16
+
17
+ class EASTdetector:
18
+ def __init__(self, model_name= "east", model_path=None):
19
+ self.model_path = model_path
20
+ # self.model_manager = ModelManager()
21
+ # self.model_manager.ensure_model(model_name)
22
+ # self.ensure_model(self.model_name)
23
+ # self.root_model_dir = "BharatSTR/bharatOCR/detection/East/tmp"
24
+
25
+ def detect(self, image_path, model_checkpoint, device):
26
+ # Load image
27
+ im = cv2.imread(image_path)
28
+ # im = cv2.imread(image_path)[:, :, ::-1]
29
+
30
+ # Initialize the EAST model and load checkpoint
31
+ model = East()
32
+ model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
33
+
34
+ # Load the model checkpoint with weights_only=True
35
+ checkpoint = torch.load(model_checkpoint, map_location=torch.device(device), weights_only=True)
36
+ model.load_state_dict(checkpoint['state_dict'])
37
+ model.eval()
38
+
39
+ # Resize image and convert to tensor format
40
+ im_resized, (ratio_h, ratio_w) = utils.resize_image(im)
41
+ im_resized = im_resized.astype(np.float32).transpose(2, 0, 1)
42
+ im_tensor = torch.from_numpy(im_resized).unsqueeze(0).cpu()
43
+
44
+ # Inference
45
+ timer = {'net': 0, 'restore': 0, 'nms': 0}
46
+ start = time.time()
47
+ score, geometry = model(im_tensor)
48
+ timer['net'] = time.time() - start
49
+
50
+ # Process output
51
+ score = score.permute(0, 2, 3, 1).data.cpu().numpy()
52
+ geometry = geometry.permute(0, 2, 3, 1).data.cpu().numpy()
53
+
54
+ # Detect boxes
55
+ boxes, timer = utils.detect(
56
+ score_map=score, geo_map=geometry, timer=timer,
57
+ score_map_thresh=cfg.score_map_thresh, box_thresh=cfg.box_thresh,
58
+ nms_thres=cfg.box_thresh
59
+ )
60
+ bbox_result_dict = {'detections': []}
61
+
62
+ # Parse detected boxes and adjust coordinates
63
+ if boxes is not None:
64
+ boxes = boxes[:, :8].reshape((-1, 4, 2))
65
+ boxes[:, :, 0] /= ratio_w
66
+ boxes[:, :, 1] /= ratio_h
67
+ for box in boxes:
68
+ box = utils.sort_poly(box.astype(np.int32))
69
+ if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
70
+ continue
71
+ bbox_result_dict['detections'].append([
72
+ [int(coord[0]), int(coord[1])] for coord in box
73
+ ])
74
+
75
+ return bbox_result_dict
76
+
77
+ # if __name__ == "__main__":
78
+ # import argparse
79
+ # parser = argparse.ArgumentParser(description='Text detection using EAST model')
80
+ # parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
81
+ # parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
82
+ # parser.add_argument('--model_checkpoint', type=str, required=True, help='Path to the model checkpoint file')
83
+ # args = parser.parse_args()
84
+
85
+ # # Run prediction and get results as dictionary
86
+ # detection_result = predict(args.image_path, args.device, args.model_checkpoint)
87
+ # print(detection_result)
IndicPhotoOCR/detection/east_locality_aware_nms.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ from shapely.geometry import Polygon
4
+
5
+
6
+ def intersection(g, p):
7
+ g = Polygon(g[:8].reshape((4, 2)))
8
+ p = Polygon(p[:8].reshape((4, 2)))
9
+ if not g.is_valid or not p.is_valid:
10
+ return 0
11
+ inter = Polygon(g).intersection(Polygon(p)).area
12
+ union = g.area + p.area - inter
13
+ if union == 0:
14
+ return 0
15
+ else:
16
+ return inter/union
17
+
18
+
19
+ def weighted_merge(g, p):
20
+ # g[0]=min(g[0],p[0])
21
+ # g[1] = min(g[1], p[1])
22
+ # g[4] = max(g[4], p[4])
23
+ # g[5]= max(g[5],p[5])
24
+ #
25
+ # g[2] = max(g[2], p[2])
26
+ # g[3] = min(g[3], p[3])
27
+ # g[6] = min(g[6], p[6])
28
+ # g[7] = max(g[7], p[7])
29
+
30
+ g[:8] = (g[8] * g[:8] + p[8] * p[:8])/(g[8] + p[8])
31
+ g[8] = (g[8] + p[8])
32
+ return g
33
+
34
+
35
+ def standard_nms(S, thres):
36
+ order = np.argsort(S[:, 8])[::-1]
37
+ keep = []
38
+ while order.size > 0:
39
+ i = order[0]
40
+ keep.append(i)
41
+ ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
42
+
43
+ inds = np.where(ovr <= thres)[0]
44
+ order = order[inds+1]
45
+
46
+ return S[keep]
47
+
48
+
49
+ def nms_locality(polys, thres=0.3):
50
+ '''
51
+ locality aware nms of EAST
52
+ :param polys: a N*9 numpy array. first 8 coordinates, then prob
53
+ :return: boxes after nms
54
+ '''
55
+ S = []
56
+ p = None
57
+ for g in polys:
58
+ if p is not None and intersection(g, p) > thres:
59
+ p = weighted_merge(g, p)
60
+ else:
61
+ if p is not None:
62
+ S.append(p)
63
+ p = g
64
+ if p is not None:
65
+ S.append(p)
66
+
67
+ if len(S) == 0:
68
+ return np.array([])
69
+ return standard_nms(np.array(S), thres)
70
+
71
+
72
+ if __name__ == '__main__':
73
+ # 343,350,448,135,474,143,369,359
74
+ print(Polygon(np.array([[343, 350], [448, 135],
75
+ [474, 143], [369, 359]])).area)
IndicPhotoOCR/detection/east_model.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch.nn as nn
3
+ import math
4
+ import torch
5
+
6
+
7
+ from IndicPhotoOCR.detection import east_config as cfg
8
+ from IndicPhotoOCR.detection import east_utils as utils
9
+
10
+
11
+ def conv_bn(inp, oup, stride):
12
+ return nn.Sequential(
13
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
14
+ nn.BatchNorm2d(oup),
15
+ nn.ReLU6(inplace=True)
16
+ )
17
+
18
+
19
+ class InvertedResidual(nn.Module):
20
+ def __init__(self, inp, oup, stride, expand_ratio):
21
+ super(InvertedResidual, self).__init__()
22
+ self.stride = stride
23
+ assert stride in [1, 2]
24
+
25
+ hidden_dim = round(inp * expand_ratio)
26
+ self.use_res_connect = self.stride == 1 and inp == oup
27
+
28
+ if expand_ratio == 1:
29
+ self.conv = nn.Sequential(
30
+ # dw
31
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
32
+ nn.BatchNorm2d(hidden_dim),
33
+ nn.ReLU6(inplace=True),
34
+ # pw-linear
35
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
36
+ nn.BatchNorm2d(oup),
37
+ )
38
+ else:
39
+ self.conv = nn.Sequential(
40
+ # pw
41
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
42
+ nn.BatchNorm2d(hidden_dim),
43
+ nn.ReLU6(inplace=True),
44
+ # dw
45
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
46
+ nn.BatchNorm2d(hidden_dim),
47
+ nn.ReLU6(inplace=True),
48
+ # pw-linear
49
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
50
+ nn.BatchNorm2d(oup),
51
+ )
52
+
53
+ def forward(self, x):
54
+ if self.use_res_connect:
55
+ return x + self.conv(x)
56
+ else:
57
+ return self.conv(x)
58
+
59
+
60
+ class MobileNetV2(nn.Module):
61
+ def __init__(self, width_mult=1.):
62
+ super(MobileNetV2, self).__init__()
63
+ block = InvertedResidual
64
+ input_channel = 32
65
+ last_channel = 1280
66
+ interverted_residual_setting = [
67
+ # t, c, n, s
68
+ [1, 16, 1, 1],
69
+ [6, 24, 2, 2],
70
+ [6, 32, 3, 2],
71
+ [6, 64, 4, 2],
72
+ [6, 96, 3, 1],
73
+ [6, 160, 3, 2],
74
+ # [6, 320, 1, 1],
75
+ ]
76
+
77
+ # building first layer
78
+ # assert input_size % 32 == 0
79
+ input_channel = int(input_channel * width_mult)
80
+ self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
81
+ self.features = [conv_bn(3, input_channel, 2)]
82
+ # building inverted residual blocks
83
+ for t, c, n, s in interverted_residual_setting:
84
+ output_channel = int(c * width_mult)
85
+ for i in range(n):
86
+ if i == 0:
87
+ self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
88
+ else:
89
+ self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
90
+ input_channel = output_channel
91
+
92
+ # make it nn.Sequential
93
+ self.features = nn.Sequential(*self.features)
94
+
95
+ self._initialize_weights()
96
+
97
+ def forward(self, x):
98
+ x = self.features(x)
99
+ # x = x.mean(3).mean(2)
100
+ # x = self.classifier(x)
101
+ return x
102
+
103
+ def _initialize_weights(self):
104
+ for m in self.modules():
105
+ if isinstance(m, nn.Conv2d):
106
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
107
+ m.weight.data.normal_(0, math.sqrt(2. / n))
108
+ if m.bias is not None:
109
+ m.bias.data.zero_()
110
+ elif isinstance(m, nn.BatchNorm2d):
111
+ m.weight.data.fill_(1)
112
+ m.bias.data.zero_()
113
+ elif isinstance(m, nn.Linear):
114
+ n = m.weight.size(1)
115
+ m.weight.data.normal_(0, 0.01)
116
+ m.bias.data.zero_()
117
+
118
+
119
+ def mobilenet(pretrained=True, **kwargs):
120
+ """
121
+ Constructs a ResNet-50 model.
122
+ Args:
123
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
124
+ """
125
+ model = MobileNetV2()
126
+ if pretrained:
127
+ model_dict = model.state_dict()
128
+ pretrained_dict = torch.load(cfg.pretrained_basemodel_path,map_location=torch.device('cpu'), weights_only=True)
129
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
130
+ model_dict.update(pretrained_dict)
131
+ model.load_state_dict(model_dict)
132
+ # state_dict = torch.load(cfg.pretrained_basemodel_path) # add map_location='cpu' if no gpu
133
+ # model.load_state_dict(state_dict)
134
+
135
+ return model
136
+
137
+
138
+ class East(nn.Module):
139
+ def __init__(self):
140
+ super(East, self).__init__()
141
+ self.mobilenet = mobilenet(True)
142
+ # self.si for stage i
143
+ self.s1 = nn.Sequential(*list(self.mobilenet.children())[0][0:4])
144
+ self.s2 = nn.Sequential(*list(self.mobilenet.children())[0][4:7])
145
+ self.s3 = nn.Sequential(*list(self.mobilenet.children())[0][7:14])
146
+ self.s4 = nn.Sequential(*list(self.mobilenet.children())[0][14:17])
147
+
148
+ self.conv1 = nn.Conv2d(160+96, 128, 1)
149
+ self.bn1 = nn.BatchNorm2d(128)
150
+ self.relu1 = nn.ReLU()
151
+
152
+ self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
153
+ self.bn2 = nn.BatchNorm2d(128)
154
+ self.relu2 = nn.ReLU()
155
+
156
+ self.conv3 = nn.Conv2d(128+32, 64, 1)
157
+ self.bn3 = nn.BatchNorm2d(64)
158
+ self.relu3 = nn.ReLU()
159
+
160
+ self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
161
+ self.bn4 = nn.BatchNorm2d(64)
162
+ self.relu4 = nn.ReLU()
163
+
164
+ self.conv5 = nn.Conv2d(64+24, 64, 1)
165
+ self.bn5 = nn.BatchNorm2d(64)
166
+ self.relu5 = nn.ReLU()
167
+
168
+ self.conv6 = nn.Conv2d(64, 32, 3, padding=1)
169
+ self.bn6 = nn.BatchNorm2d(32)
170
+ self.relu6 = nn.ReLU()
171
+
172
+ self.conv7 = nn.Conv2d(32, 32, 3, padding=1)
173
+ self.bn7 = nn.BatchNorm2d(32)
174
+ self.relu7 = nn.ReLU()
175
+
176
+ self.conv8 = nn.Conv2d(32, 1, 1)
177
+ self.sigmoid1 = nn.Sigmoid()
178
+ self.conv9 = nn.Conv2d(32, 4, 1)
179
+ self.sigmoid2 = nn.Sigmoid()
180
+ self.conv10 = nn.Conv2d(32, 1, 1)
181
+ self.sigmoid3 = nn.Sigmoid()
182
+ self.unpool1 = nn.Upsample(scale_factor=2, mode='bilinear')
183
+ self.unpool2 = nn.Upsample(scale_factor=2, mode='bilinear')
184
+ self.unpool3 = nn.Upsample(scale_factor=2, mode='bilinear')
185
+
186
+ # utils.init_weights([self.conv1,self.conv2,self.conv3,self.conv4,
187
+ # self.conv5,self.conv6,self.conv7,self.conv8,
188
+ # self.conv9,self.conv10,self.bn1,self.bn2,
189
+ # self.bn3,self.bn4,self.bn5,self.bn6,self.bn7])
190
+
191
+ def forward(self, images):
192
+ images = utils.mean_image_subtraction(images)
193
+
194
+ f0 = self.s1(images)
195
+ f1 = self.s2(f0)
196
+ f2 = self.s3(f1)
197
+ f3 = self.s4(f2)
198
+
199
+ # _, f = self.mobilenet(images)
200
+ h = f3 # bs 2048 w/32 h/32
201
+ g = (self.unpool1(h)) # bs 2048 w/16 h/16
202
+ c = self.conv1(torch.cat((g, f2), 1))
203
+ c = self.bn1(c)
204
+ c = self.relu1(c)
205
+
206
+ h = self.conv2(c) # bs 128 w/16 h/16
207
+ h = self.bn2(h)
208
+ h = self.relu2(h)
209
+ g = self.unpool2(h) # bs 128 w/8 h/8
210
+ c = self.conv3(torch.cat((g, f1), 1))
211
+ c = self.bn3(c)
212
+ c = self.relu3(c)
213
+
214
+ h = self.conv4(c) # bs 64 w/8 h/8
215
+ h = self.bn4(h)
216
+ h = self.relu4(h)
217
+ g = self.unpool3(h) # bs 64 w/4 h/4
218
+ c = self.conv5(torch.cat((g, f0), 1))
219
+ c = self.bn5(c)
220
+ c = self.relu5(c)
221
+
222
+ h = self.conv6(c) # bs 32 w/4 h/4
223
+ h = self.bn6(h)
224
+ h = self.relu6(h)
225
+ g = self.conv7(h) # bs 32 w/4 h/4
226
+ g = self.bn7(g)
227
+ g = self.relu7(g)
228
+
229
+ F_score = self.conv8(g) # bs 1 w/4 h/4
230
+ F_score = self.sigmoid1(F_score)
231
+ geo_map = self.conv9(g)
232
+ geo_map = self.sigmoid2(geo_map) * 512
233
+ angle_map = self.conv10(g)
234
+ angle_map = self.sigmoid3(angle_map)
235
+ angle_map = (angle_map - 0.5) * math.pi / 2
236
+
237
+ F_geometry = torch.cat((geo_map, angle_map), 1) # bs 5 w/4 h/4
238
+
239
+ return F_score, F_geometry
240
+
241
+
242
+ model=East()
IndicPhotoOCR/detection/east_preprossing.py ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # coding:utf-8
3
+ import glob
4
+ import csv
5
+ import cv2
6
+ import os
7
+ import numpy as np
8
+ from shapely.geometry import Polygon
9
+
10
+
11
+ from IndicPhotoOCR.detection import east_config as cfg
12
+ from IndicPhotoOCR.detection import east_utils
13
+
14
+
15
+ def get_images(img_root):
16
+ files = []
17
+ for ext in ['jpg']:
18
+ files.extend(glob.glob(
19
+ os.path.join(img_root, '*.{}'.format(ext))))
20
+ # print(glob.glob(
21
+ # os.path.join(FLAGS.training_data_path, '*.{}'.format(ext))))
22
+ return files
23
+
24
+
25
+ def load_annoataion(p):
26
+ '''
27
+ load annotation from the text file
28
+ :param p:
29
+ :return:
30
+ '''
31
+ text_polys = []
32
+ text_tags = []
33
+ if not os.path.exists(p):
34
+ return np.array(text_polys, dtype=np.float32)
35
+ with open(p, 'r', encoding='UTF-8') as f:
36
+ reader = csv.reader(f)
37
+ for line in reader:
38
+ label = line[-1]
39
+ # strip BOM. \ufeff for python3, \xef\xbb\bf for python2
40
+ line = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in line]
41
+
42
+ x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8]))
43
+ text_polys.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
44
+ # print(text_polys)
45
+ if label == '*' or label == '###':
46
+ text_tags.append(True)
47
+ else:
48
+ text_tags.append(False)
49
+ return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=np.bool)
50
+
51
+
52
+ def polygon_area(poly):
53
+ '''
54
+ compute area of a polygon
55
+ :param poly:
56
+ :return:
57
+ '''
58
+ edge = [
59
+ (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
60
+ (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
61
+ (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
62
+ (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])
63
+ ]
64
+ return np.sum(edge) / 2.
65
+
66
+
67
+ def check_and_validate_polys(polys, tags, xxx_todo_changeme):
68
+ '''
69
+ check so that the text poly is in the same direction,
70
+ and also filter some invalid polygons
71
+ :param polys:
72
+ :param tags:
73
+ :return:
74
+ '''
75
+ (h, w) = xxx_todo_changeme
76
+ if polys.shape[0] == 0:
77
+ return polys
78
+ polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
79
+ polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
80
+
81
+ validated_polys = []
82
+ validated_tags = []
83
+
84
+ # 判断四边形的点时针方向,以及是否是有效四边形
85
+ for poly, tag in zip(polys, tags):
86
+ p_area = polygon_area(poly)
87
+ if abs(p_area) < 1:
88
+ # print poly
89
+ print('invalid poly')
90
+ continue
91
+ if p_area > 0:
92
+ print('poly in wrong direction')
93
+ poly = poly[(0, 3, 2, 1), :]
94
+ validated_polys.append(poly)
95
+ validated_tags.append(tag)
96
+ return np.array(validated_polys), np.array(validated_tags)
97
+
98
+
99
+ def crop_area(im, polys, tags, crop_background=False, max_tries=100):
100
+ '''
101
+ make random crop from the input image
102
+ :param im:
103
+ :param polys:
104
+ :param tags:
105
+ :param crop_background:
106
+ :param max_tries:
107
+ :return:
108
+ '''
109
+ h, w, _ = im.shape
110
+ pad_h = h // 10
111
+ pad_w = w // 10
112
+ h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
113
+ w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
114
+ for poly in polys:
115
+ poly = np.round(poly, decimals=0).astype(np.int32)
116
+ minx = np.min(poly[:, 0])
117
+ maxx = np.max(poly[:, 0])
118
+ w_array[minx + pad_w:maxx + pad_w] = 1
119
+ miny = np.min(poly[:, 1])
120
+ maxy = np.max(poly[:, 1])
121
+ h_array[miny + pad_h:maxy + pad_h] = 1
122
+ # ensure the cropped area not across a text,保证裁剪区域不能与文本交叉
123
+ h_axis = np.where(h_array == 0)[0]
124
+ w_axis = np.where(w_array == 0)[0]
125
+ if len(h_axis) == 0 or len(w_axis) == 0:
126
+ return im, polys, tags
127
+ for i in range(max_tries): # 试验50次
128
+ xx = np.random.choice(w_axis, size=2)
129
+ xmin = np.min(xx) - pad_w
130
+ xmax = np.max(xx) - pad_w
131
+ xmin = np.clip(xmin, 0, w - 1)
132
+ xmax = np.clip(xmax, 0, w - 1)
133
+ yy = np.random.choice(h_axis, size=2)
134
+ ymin = np.min(yy) - pad_h
135
+ ymax = np.max(yy) - pad_h
136
+ ymin = np.clip(ymin, 0, h - 1)
137
+ ymax = np.clip(ymax, 0, h - 1)
138
+ if xmax - xmin < cfg.min_crop_side_ratio * w or ymax - ymin < cfg.min_crop_side_ratio * h:
139
+ # area too small
140
+ continue
141
+ if polys.shape[0] != 0:
142
+ poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
143
+ & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
144
+ selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
145
+ else:
146
+ selected_polys = []
147
+ if len(selected_polys) == 0:
148
+ # no text in this area
149
+ if crop_background:
150
+ return im[ymin:ymax + 1, xmin:xmax + 1, :], polys[selected_polys], tags[selected_polys]
151
+ else:
152
+ continue
153
+ im = im[ymin:ymax + 1, xmin:xmax + 1, :]
154
+ polys = polys[selected_polys]
155
+ tags = tags[selected_polys]
156
+ polys[:, :, 0] -= xmin
157
+ polys[:, :, 1] -= ymin
158
+ return im, polys, tags
159
+
160
+ return im, polys, tags
161
+
162
+
163
+ def shrink_poly(poly, r):
164
+ '''
165
+ fit a poly inside the origin poly, maybe bugs here...
166
+ used for generate the score map
167
+ :param poly: the text poly
168
+ :param r: r in the paper
169
+ :return: the shrinked poly
170
+ '''
171
+ # shrink ratio
172
+ R = 0.3
173
+ # find the longer pair
174
+ if np.linalg.norm(poly[0] - poly[1]) + np.linalg.norm(poly[2] - poly[3]) > \
175
+ np.linalg.norm(poly[0] - poly[3]) + np.linalg.norm(poly[1] - poly[2]):
176
+ # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
177
+ ## p0, p1
178
+ theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
179
+ poly[0][0] += R * r[0] * np.cos(theta)
180
+ poly[0][1] += R * r[0] * np.sin(theta)
181
+ poly[1][0] -= R * r[1] * np.cos(theta)
182
+ poly[1][1] -= R * r[1] * np.sin(theta)
183
+ ## p2, p3
184
+ theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
185
+ poly[3][0] += R * r[3] * np.cos(theta)
186
+ poly[3][1] += R * r[3] * np.sin(theta)
187
+ poly[2][0] -= R * r[2] * np.cos(theta)
188
+ poly[2][1] -= R * r[2] * np.sin(theta)
189
+ ## p0, p3
190
+ theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
191
+ poly[0][0] += R * r[0] * np.sin(theta)
192
+ poly[0][1] += R * r[0] * np.cos(theta)
193
+ poly[3][0] -= R * r[3] * np.sin(theta)
194
+ poly[3][1] -= R * r[3] * np.cos(theta)
195
+ ## p1, p2
196
+ theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
197
+ poly[1][0] += R * r[1] * np.sin(theta)
198
+ poly[1][1] += R * r[1] * np.cos(theta)
199
+ poly[2][0] -= R * r[2] * np.sin(theta)
200
+ poly[2][1] -= R * r[2] * np.cos(theta)
201
+ else:
202
+ ## p0, p3
203
+ # print poly
204
+ theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
205
+ poly[0][0] += R * r[0] * np.sin(theta)
206
+ poly[0][1] += R * r[0] * np.cos(theta)
207
+ poly[3][0] -= R * r[3] * np.sin(theta)
208
+ poly[3][1] -= R * r[3] * np.cos(theta)
209
+ ## p1, p2
210
+ theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
211
+ poly[1][0] += R * r[1] * np.sin(theta)
212
+ poly[1][1] += R * r[1] * np.cos(theta)
213
+ poly[2][0] -= R * r[2] * np.sin(theta)
214
+ poly[2][1] -= R * r[2] * np.cos(theta)
215
+ ## p0, p1
216
+ theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
217
+ poly[0][0] += R * r[0] * np.cos(theta)
218
+ poly[0][1] += R * r[0] * np.sin(theta)
219
+ poly[1][0] -= R * r[1] * np.cos(theta)
220
+ poly[1][1] -= R * r[1] * np.sin(theta)
221
+ ## p2, p3
222
+ theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
223
+ poly[3][0] += R * r[3] * np.cos(theta)
224
+ poly[3][1] += R * r[3] * np.sin(theta)
225
+ poly[2][0] -= R * r[2] * np.cos(theta)
226
+ poly[2][1] -= R * r[2] * np.sin(theta)
227
+ return poly
228
+
229
+
230
+ # def point_dist_to_line(p1, p2, p3):
231
+ # # compute the distance from p3 to p1-p2
232
+ # return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1)
233
+
234
+
235
+ # 点p3到直线p12的距离
236
+ def point_dist_to_line(p1, p2, p3):
237
+ # compute the distance from p3 to p1-p2
238
+ # return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1)
239
+ a = np.linalg.norm(p1 - p2)
240
+ b = np.linalg.norm(p2 - p3)
241
+ c = np.linalg.norm(p3 - p1)
242
+ s = (a + b + c) / 2.0
243
+ area = np.abs((s * (s - a) * (s - b) * (s - c))) ** 0.5
244
+ if a < 1.0:
245
+ return (b + c) / 2.0
246
+ return 2 * area / a
247
+
248
+
249
+ def fit_line(p1, p2):
250
+ # fit a line ax+by+c = 0
251
+ if p1[0] == p1[1]:
252
+ return [1., 0., -p1[0]]
253
+ else:
254
+ [k, b] = np.polyfit(p1, p2, deg=1)
255
+ return [k, -1., b]
256
+
257
+
258
+ def line_cross_point(line1, line2):
259
+ # line1 0= ax+by+c, compute the cross point of line1 and line2
260
+ if line1[0] != 0 and line1[0] == line2[0]:
261
+ print('Cross point does not exist')
262
+ return None
263
+ if line1[0] == 0 and line2[0] == 0:
264
+ print('Cross point does not exist')
265
+ return None
266
+ if line1[1] == 0:
267
+ x = -line1[2]
268
+ y = line2[0] * x + line2[2]
269
+ elif line2[1] == 0:
270
+ x = -line2[2]
271
+ y = line1[0] * x + line1[2]
272
+ else:
273
+ k1, _, b1 = line1
274
+ k2, _, b2 = line2
275
+ x = -(b1 - b2) / (k1 - k2)
276
+ y = k1 * x + b1
277
+ return np.array([x, y], dtype=np.float32)
278
+
279
+
280
+ def line_verticle(line, point):
281
+ # get the verticle line from line across point
282
+ if line[1] == 0:
283
+ verticle = [0, -1, point[1]]
284
+ else:
285
+ if line[0] == 0:
286
+ verticle = [1, 0, -point[0]]
287
+ else:
288
+ verticle = [-1. / line[0], -1, point[1] - (-1 / line[0] * point[0])]
289
+ return verticle
290
+
291
+
292
+ def rectangle_from_parallelogram(poly):
293
+ '''
294
+ fit a rectangle from a parallelogram
295
+ :param poly:
296
+ :return:
297
+ '''
298
+ p0, p1, p2, p3 = poly
299
+ angle_p0 = np.arccos(np.dot(p1 - p0, p3 - p0) / (np.linalg.norm(p0 - p1) * np.linalg.norm(p3 - p0)))
300
+ if angle_p0 < 0.5 * np.pi:
301
+ if np.linalg.norm(p0 - p1) > np.linalg.norm(p0 - p3):
302
+ # p0 and p2
303
+ ## p0
304
+ p2p3 = fit_line([p2[0], p3[0]], [p2[1], p3[1]])
305
+ p2p3_verticle = line_verticle(p2p3, p0)
306
+
307
+ new_p3 = line_cross_point(p2p3, p2p3_verticle)
308
+ ## p2
309
+ p0p1 = fit_line([p0[0], p1[0]], [p0[1], p1[1]])
310
+ p0p1_verticle = line_verticle(p0p1, p2)
311
+
312
+ new_p1 = line_cross_point(p0p1, p0p1_verticle)
313
+ return np.array([p0, new_p1, p2, new_p3], dtype=np.float32)
314
+ else:
315
+ p1p2 = fit_line([p1[0], p2[0]], [p1[1], p2[1]])
316
+ p1p2_verticle = line_verticle(p1p2, p0)
317
+
318
+ new_p1 = line_cross_point(p1p2, p1p2_verticle)
319
+ p0p3 = fit_line([p0[0], p3[0]], [p0[1], p3[1]])
320
+ p0p3_verticle = line_verticle(p0p3, p2)
321
+
322
+ new_p3 = line_cross_point(p0p3, p0p3_verticle)
323
+ return np.array([p0, new_p1, p2, new_p3], dtype=np.float32)
324
+ else:
325
+ if np.linalg.norm(p0 - p1) > np.linalg.norm(p0 - p3):
326
+ # p1 and p3
327
+ ## p1
328
+ p2p3 = fit_line([p2[0], p3[0]], [p2[1], p3[1]])
329
+ p2p3_verticle = line_verticle(p2p3, p1)
330
+
331
+ new_p2 = line_cross_point(p2p3, p2p3_verticle)
332
+ ## p3
333
+ p0p1 = fit_line([p0[0], p1[0]], [p0[1], p1[1]])
334
+ p0p1_verticle = line_verticle(p0p1, p3)
335
+
336
+ new_p0 = line_cross_point(p0p1, p0p1_verticle)
337
+ return np.array([new_p0, p1, new_p2, p3], dtype=np.float32)
338
+ else:
339
+ p0p3 = fit_line([p0[0], p3[0]], [p0[1], p3[1]])
340
+ p0p3_verticle = line_verticle(p0p3, p1)
341
+
342
+ new_p0 = line_cross_point(p0p3, p0p3_verticle)
343
+ p1p2 = fit_line([p1[0], p2[0]], [p1[1], p2[1]])
344
+ p1p2_verticle = line_verticle(p1p2, p3)
345
+
346
+ new_p2 = line_cross_point(p1p2, p1p2_verticle)
347
+ return np.array([new_p0, p1, new_p2, p3], dtype=np.float32)
348
+
349
+
350
+ def sort_rectangle(poly):
351
+ # sort the four coordinates of the polygon, points in poly should be sorted clockwise
352
+ # First find the lowest point
353
+ p_lowest = np.argmax(poly[:, 1])
354
+ if np.count_nonzero(poly[:, 1] == poly[p_lowest, 1]) == 2:
355
+ # 底边平行于X轴, 那么p0为左上角 - if the bottom line is parallel to x-axis, then p0 must be the upper-left corner
356
+ p0_index = np.argmin(np.sum(poly, axis=1))
357
+ p1_index = (p0_index + 1) % 4
358
+ p2_index = (p0_index + 2) % 4
359
+ p3_index = (p0_index + 3) % 4
360
+ return poly[[p0_index, p1_index, p2_index, p3_index]], 0.
361
+ else:
362
+ # 找到最低点右边的点 - find the point that sits right to the lowest point
363
+ p_lowest_right = (p_lowest - 1) % 4
364
+ p_lowest_left = (p_lowest + 1) % 4
365
+ angle = np.arctan(
366
+ -(poly[p_lowest][1] - poly[p_lowest_right][1]) / (poly[p_lowest][0] - poly[p_lowest_right][0]))
367
+ # assert angle > 0
368
+ if angle <= 0:
369
+ print(angle, poly[p_lowest], poly[p_lowest_right])
370
+ if angle / np.pi * 180 > 45:
371
+ # 这个点为p2 - this point is p2
372
+ p2_index = p_lowest
373
+ p1_index = (p2_index - 1) % 4
374
+ p0_index = (p2_index - 2) % 4
375
+ p3_index = (p2_index + 1) % 4
376
+ return poly[[p0_index, p1_index, p2_index, p3_index]], -(np.pi / 2 - angle)
377
+ else:
378
+ # 这个点为p3 - this point is p3
379
+ p3_index = p_lowest
380
+ p0_index = (p3_index + 1) % 4
381
+ p1_index = (p3_index + 2) % 4
382
+ p2_index = (p3_index + 3) % 4
383
+ return poly[[p0_index, p1_index, p2_index, p3_index]], angle
384
+
385
+
386
+ def restore_rectangle_rbox(origin, geometry):
387
+ d = geometry[:, :4]
388
+ angle = geometry[:, 4]
389
+ # for angle > 0
390
+ origin_0 = origin[angle >= 0]
391
+ d_0 = d[angle >= 0]
392
+ angle_0 = angle[angle >= 0]
393
+ if origin_0.shape[0] > 0:
394
+ p = np.array([np.zeros(d_0.shape[0]), -d_0[:, 0] - d_0[:, 2],
395
+ d_0[:, 1] + d_0[:, 3], -d_0[:, 0] - d_0[:, 2],
396
+ d_0[:, 1] + d_0[:, 3], np.zeros(d_0.shape[0]),
397
+ np.zeros(d_0.shape[0]), np.zeros(d_0.shape[0]),
398
+ d_0[:, 3], -d_0[:, 2]])
399
+ p = p.transpose((1, 0)).reshape((-1, 5, 2)) # N*5*2
400
+
401
+ rotate_matrix_x = np.array([np.cos(angle_0), np.sin(angle_0)]).transpose((1, 0))
402
+ rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) # N*5*2
403
+
404
+ rotate_matrix_y = np.array([-np.sin(angle_0), np.cos(angle_0)]).transpose((1, 0))
405
+ rotate_matrix_y = np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1))
406
+
407
+ p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis] # N*5*1
408
+ p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis] # N*5*1
409
+
410
+ p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2) # N*5*2
411
+
412
+ p3_in_origin = origin_0 - p_rotate[:, 4, :]
413
+ new_p0 = p_rotate[:, 0, :] + p3_in_origin # N*2
414
+ new_p1 = p_rotate[:, 1, :] + p3_in_origin
415
+ new_p2 = p_rotate[:, 2, :] + p3_in_origin
416
+ new_p3 = p_rotate[:, 3, :] + p3_in_origin
417
+
418
+ new_p_0 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :],
419
+ new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1) # N*4*2
420
+ else:
421
+ new_p_0 = np.zeros((0, 4, 2))
422
+ # for angle < 0
423
+ origin_1 = origin[angle < 0]
424
+ d_1 = d[angle < 0]
425
+ angle_1 = angle[angle < 0]
426
+ if origin_1.shape[0] > 0:
427
+ p = np.array([-d_1[:, 1] - d_1[:, 3], -d_1[:, 0] - d_1[:, 2],
428
+ np.zeros(d_1.shape[0]), -d_1[:, 0] - d_1[:, 2],
429
+ np.zeros(d_1.shape[0]), np.zeros(d_1.shape[0]),
430
+ -d_1[:, 1] - d_1[:, 3], np.zeros(d_1.shape[0]),
431
+ -d_1[:, 1], -d_1[:, 2]])
432
+ p = p.transpose((1, 0)).reshape((-1, 5, 2)) # N*5*2
433
+
434
+ rotate_matrix_x = np.array([np.cos(-angle_1), -np.sin(-angle_1)]).transpose((1, 0))
435
+ rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) # N*5*2
436
+
437
+ rotate_matrix_y = np.array([np.sin(-angle_1), np.cos(-angle_1)]).transpose((1, 0))
438
+ rotate_matrix_y = np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1))
439
+
440
+ p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis] # N*5*1
441
+ p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis] # N*5*1
442
+
443
+ p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2) # N*5*2
444
+
445
+ p3_in_origin = origin_1 - p_rotate[:, 4, :]
446
+ new_p0 = p_rotate[:, 0, :] + p3_in_origin # N*2
447
+ new_p1 = p_rotate[:, 1, :] + p3_in_origin
448
+ new_p2 = p_rotate[:, 2, :] + p3_in_origin
449
+ new_p3 = p_rotate[:, 3, :] + p3_in_origin
450
+
451
+ new_p_1 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :],
452
+ new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1) # N*4*2
453
+ else:
454
+ new_p_1 = np.zeros((0, 4, 2))
455
+ return np.concatenate([new_p_0, new_p_1])
456
+
457
+
458
+ def restore_rectangle(origin, geometry):
459
+ return restore_rectangle_rbox(origin, geometry)
460
+
461
+
462
+ def generate_rbox(im_size, polys, tags):
463
+ h, w = im_size
464
+ poly_mask = np.zeros((h, w), dtype=np.uint8)
465
+ score_map = np.zeros((h, w), dtype=np.uint8)
466
+ geo_map = np.zeros((h, w, 5), dtype=np.float32)
467
+ # mask used during traning, to ignore some hard areas,用于忽略那些过小的文本
468
+ training_mask = np.ones((h, w), dtype=np.uint8)
469
+ for poly_idx, poly_tag in enumerate(zip(polys, tags)):
470
+ poly = poly_tag[0]
471
+ tag = poly_tag[1]
472
+
473
+ # 对每个顶点,找到经过他的两条边中较短的那条
474
+ r = [None, None, None, None]
475
+ for i in range(4):
476
+ r[i] = min(np.linalg.norm(poly[i] - poly[(i + 1) % 4]),
477
+ np.linalg.norm(poly[i] - poly[(i - 1) % 4]))
478
+ # score map
479
+ # 放缩边框为之前的0.3倍,并对边框对应score图中的位置进行填充
480
+ shrinked_poly = shrink_poly(poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
481
+ cv2.fillPoly(score_map, shrinked_poly, 1)
482
+ cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
483
+ # if the poly is too small, then ignore it during training
484
+ # 如果文本框标签太小或者txt中没具体标记是什么内容,即*或者###,则加掩模,训练时忽略该部分
485
+ poly_h = min(np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2]))
486
+ poly_w = min(np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3]))
487
+ if min(poly_h, poly_w) < cfg.min_text_size:
488
+ cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
489
+ if tag:
490
+ cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
491
+
492
+ # 当前新加入的文本框区域像素点
493
+ xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
494
+ # if geometry == 'RBOX':
495
+ # 对任意两个顶点的组合生成一个平行四边形 - generate a parallelogram for any combination of two vertices
496
+ fitted_parallelograms = []
497
+ for i in range(4):
498
+ # 选中p0和p1的连线边,生成两个平行四边形
499
+ p0 = poly[i]
500
+ p1 = poly[(i + 1) % 4]
501
+ p2 = poly[(i + 2) % 4]
502
+ p3 = poly[(i + 3) % 4]
503
+ # 拟合ax+by+c=0
504
+ edge = fit_line([p0[0], p1[0]], [p0[1], p1[1]])
505
+ backward_edge = fit_line([p0[0], p3[0]], [p0[1], p3[1]])
506
+ forward_edge = fit_line([p1[0], p2[0]], [p1[1], p2[1]])
507
+ # 通过另外两个点距离edge的距离,来决定edge对应的平行线应该过p2还是p3
508
+ if point_dist_to_line(p0, p1, p2) > point_dist_to_line(p0, p1, p3):
509
+ # 平行线经过p2 - parallel lines through p2
510
+ if edge[1] == 0:
511
+ edge_opposite = [1, 0, -p2[0]]
512
+ else:
513
+ edge_opposite = [edge[0], -1, p2[1] - edge[0] * p2[0]]
514
+ else:
515
+ # 经过p3 - after p3
516
+ if edge[1] == 0:
517
+ edge_opposite = [1, 0, -p3[0]]
518
+ else:
519
+ edge_opposite = [edge[0], -1, p3[1] - edge[0] * p3[0]]
520
+ # move forward edge
521
+ new_p0 = p0
522
+ new_p1 = p1
523
+ new_p2 = p2
524
+ new_p3 = p3
525
+ new_p2 = line_cross_point(forward_edge, edge_opposite)
526
+ if point_dist_to_line(p1, new_p2, p0) > point_dist_to_line(p1, new_p2, p3):
527
+ # across p0
528
+ if forward_edge[1] == 0:
529
+ forward_opposite = [1, 0, -p0[0]]
530
+ else:
531
+ forward_opposite = [forward_edge[0], -1, p0[1] - forward_edge[0] * p0[0]]
532
+ else:
533
+ # across p3
534
+ if forward_edge[1] == 0:
535
+ forward_opposite = [1, 0, -p3[0]]
536
+ else:
537
+ forward_opposite = [forward_edge[0], -1, p3[1] - forward_edge[0] * p3[0]]
538
+ new_p0 = line_cross_point(forward_opposite, edge)
539
+ new_p3 = line_cross_point(forward_opposite, edge_opposite)
540
+ fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0])
541
+ # or move backward edge
542
+ new_p0 = p0
543
+ new_p1 = p1
544
+ new_p2 = p2
545
+ new_p3 = p3
546
+ new_p3 = line_cross_point(backward_edge, edge_opposite)
547
+ if point_dist_to_line(p0, p3, p1) > point_dist_to_line(p0, p3, p2):
548
+ # across p1
549
+ if backward_edge[1] == 0:
550
+ backward_opposite = [1, 0, -p1[0]]
551
+ else:
552
+ backward_opposite = [backward_edge[0], -1, p1[1] - backward_edge[0] * p1[0]]
553
+ else:
554
+ # across p2
555
+ if backward_edge[1] == 0:
556
+ backward_opposite = [1, 0, -p2[0]]
557
+ else:
558
+ backward_opposite = [backward_edge[0], -1, p2[1] - backward_edge[0] * p2[0]]
559
+ new_p1 = line_cross_point(backward_opposite, edge)
560
+ new_p2 = line_cross_point(backward_opposite, edge_opposite)
561
+ fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0])
562
+
563
+ # 选定面积最小的平行四边形
564
+ areas = [Polygon(t).area for t in fitted_parallelograms]
565
+ parallelogram = np.array(fitted_parallelograms[np.argmin(areas)][:-1], dtype=np.float32)
566
+ # sort thie polygon
567
+ parallelogram_coord_sum = np.sum(parallelogram, axis=1)
568
+ min_coord_idx = np.argmin(parallelogram_coord_sum)
569
+ parallelogram = parallelogram[
570
+ [min_coord_idx, (min_coord_idx + 1) % 4, (min_coord_idx + 2) % 4, (min_coord_idx + 3) % 4]]
571
+
572
+ # 得到外包矩形即旋转角
573
+ rectange = rectangle_from_parallelogram(parallelogram)
574
+ rectange, rotate_angle = sort_rectangle(rectange)
575
+
576
+ p0_rect, p1_rect, p2_rect, p3_rect = rectange
577
+ # 对当前新加入的文本框区域像素点,根据其到矩形四边的距离修改geo_map
578
+ for y, x in xy_in_poly:
579
+ point = np.array([x, y], dtype=np.float32)
580
+ # top
581
+ geo_map[y, x, 0] = point_dist_to_line(p0_rect, p1_rect, point)
582
+ # right
583
+ geo_map[y, x, 1] = point_dist_to_line(p1_rect, p2_rect, point)
584
+ # down
585
+ geo_map[y, x, 2] = point_dist_to_line(p2_rect, p3_rect, point)
586
+ # left
587
+ geo_map[y, x, 3] = point_dist_to_line(p3_rect, p0_rect, point)
588
+ # angle
589
+ geo_map[y, x, 4] = rotate_angle
590
+ return score_map, geo_map, training_mask
591
+
592
+
593
+ def generator(index,
594
+ input_size=512,
595
+ background_ratio=3. / 8, # 纯背景样本比例
596
+ random_scale=np.array([0.5, 1, 2.0, 3.0]), # 提取多尺度图片信息
597
+ image_list=None):
598
+ try:
599
+ im_fn = image_list[index]
600
+ im = cv2.imread(im_fn)
601
+ if im is None:
602
+ print("can't find image")
603
+ return None, None, None, None, None
604
+ h, w, _ = im.shape
605
+ # 所以要把gt去掉
606
+ txt_fn = im_fn.replace(os.path.basename(im_fn).split('.')[1], 'txt')
607
+ if not os.path.exists(txt_fn):
608
+ print('text file {} does not exists'.format(txt_fn))
609
+ return None, None, None, None, None
610
+ # 加载标注框信息
611
+ text_polys, text_tags = load_annoataion(txt_fn)
612
+
613
+ text_polys, text_tags = check_and_validate_polys(text_polys, text_tags, (h, w))
614
+
615
+ # random scale this image,随机选择一种尺度
616
+ rd_scale = np.random.choice(random_scale)
617
+ im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
618
+ text_polys *= rd_scale
619
+
620
+ # random crop a area from image,3/8���选中的概率,裁剪纯背景的图片
621
+ if np.random.rand() < background_ratio:
622
+ # crop background
623
+ im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=True)
624
+ if text_polys.shape[0] > 0:
625
+ # print("cannot find background")
626
+ return None, None, None, None, None
627
+ # pad and resize image
628
+ new_h, new_w, _ = im.shape
629
+ max_h_w_i = np.max([new_h, new_w, input_size])
630
+ im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
631
+ im_padded[:new_h, :new_w, :] = im.copy()
632
+ # 将裁剪后图片扩充成512*512的图片
633
+ im = cv2.resize(im_padded, dsize=(input_size, input_size))
634
+ score_map = np.zeros((input_size, input_size), dtype=np.uint8)
635
+ geo_map_channels = 5 if cfg.geometry == 'RBOX' else 8
636
+ geo_map = np.zeros((input_size, input_size, geo_map_channels), dtype=np.float32)
637
+ training_mask = np.ones((input_size, input_size), dtype=np.uint8)
638
+ else:
639
+ # 5 / 8的选中的概率,裁剪含文本信息的图片
640
+ im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=False)
641
+ if text_polys.shape[0] == 0:
642
+ # print("cannot find txt ground")
643
+ return None, None, None, None, None
644
+ h, w, _ = im.shape
645
+ # pad the image to the training input size or the longer side of image
646
+ new_h, new_w, _ = im.shape
647
+ max_h_w_i = np.max([new_h, new_w, input_size])
648
+ im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
649
+ im_padded[:new_h, :new_w, :] = im.copy()
650
+ im = im_padded
651
+ # resize the image to input size
652
+ # 填充,resize图像至设定尺寸
653
+ new_h, new_w, _ = im.shape
654
+ resize_h = input_size
655
+ resize_w = input_size
656
+ im = cv2.resize(im, dsize=(resize_w, resize_h))
657
+ # 将文本框坐标标签等比例修改
658
+ resize_ratio_3_x = resize_w / float(new_w)
659
+ resize_ratio_3_y = resize_h / float(new_h)
660
+ text_polys[:, :, 0] *= resize_ratio_3_x
661
+ text_polys[:, :, 1] *= resize_ratio_3_y
662
+ new_h, new_w, _ = im.shape
663
+ score_map, geo_map, training_mask = generate_rbox((new_h, new_w), text_polys, text_tags)
664
+
665
+ # 将一个样本的样本内容和标签信息append
666
+ images = im[:,:,::-1].astype(np.float32)
667
+ # 文件名加入列表
668
+ image_fns = im_fn
669
+ # 512*512取提取四分之一行列
670
+ score_maps = score_map[::4, ::4, np.newaxis].astype(np.float32)
671
+ geo_maps = geo_map[::4, ::4, :].astype(np.float32)
672
+ training_masks = training_mask[::4, ::4, np.newaxis].astype(np.float32)
673
+ # 符合一个样本之后输出
674
+ return images, image_fns, score_maps, geo_maps, training_masks
675
+
676
+ except Exception as e:
677
+ import traceback
678
+ traceback.print_exc()
679
+
680
+ # print("Exception is exist!")
681
+ return None, None, None, None, None
IndicPhotoOCR/detection/east_utils.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import os
4
+ from torch.nn import init
5
+ import cv2
6
+ import numpy as np
7
+ import time
8
+ import requests
9
+
10
+ from IndicPhotoOCR.detection import east_config as cfg
11
+ from IndicPhotoOCR.detection import east_preprossing as preprossing
12
+ from IndicPhotoOCR.detection import east_locality_aware_nms as locality_aware_nms
13
+
14
+
15
+
16
+ # Example usage:
17
+ model_info = {
18
+ "east": {
19
+ "paths": [ cfg.checkpoint, cfg.pretrained_basemodel_path],
20
+ "urls" : ["https://github.com/anikde/STocr/releases/download/e0.1.0/epoch_990_checkpoint.pth.tar", "https://github.com/anikde/STocr/releases/download/e0.1.0/mobilenet_v2.pth.tar"]
21
+ },
22
+ }
23
+
24
+ class ModelManager:
25
+ def __init__(self):
26
+ # self.root_model_dir = "bharatOCR/detection/"
27
+ pass
28
+
29
+ def download_model(self, url, path):
30
+ response = requests.get(url, stream=True)
31
+ if response.status_code == 200:
32
+ with open(path, 'wb') as f:
33
+ for chunk in response.iter_content(chunk_size=8192):
34
+ if chunk: # Filter out keep-alive chunks
35
+ f.write(chunk)
36
+ print(f"Downloaded: {path}")
37
+ else:
38
+ print(f"Failed to download from {url}")
39
+
40
+ def ensure_model(self, model_name):
41
+ model_paths = model_info[model_name]["paths"] # Changed to handle multiple paths
42
+ urls = model_info[model_name]["urls"] # Changed to handle multiple URLs
43
+
44
+
45
+ for model_path, url in zip(model_paths, urls):
46
+ # full_model_path = os.path.join(self.root_model_dir, model_path)
47
+
48
+ # Ensure the model path directory exists
49
+ os.makedirs(os.path.dirname(os.path.join(*cfg.pretrained_basemodel_path.split("/"))), exist_ok=True)
50
+
51
+ if not os.path.exists(model_path):
52
+ print(f"Model not found locally. Downloading {model_name} from {url}...")
53
+ self.download_model(url, model_path)
54
+ else:
55
+ print(f"Model already exists at {model_path}. No need to download.")
56
+
57
+
58
+
59
+ # # Initialize ModelManager and ensure Hindi models are downloaded
60
+ model_manager = ModelManager()
61
+ model_manager.ensure_model("east")
62
+
63
+
64
+
65
+ def init_weights(m_list, init_type=cfg.init_type, gain=0.02):
66
+ print("EAST <==> Prepare <==> Init Network'{}' <==> Begin".format(cfg.init_type))
67
+ # this will apply to each layer
68
+ for m in m_list:
69
+ classname = m.__class__.__name__
70
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
71
+ if init_type == 'normal':
72
+ init.normal_(m.weight.data, 0.0, gain)
73
+ elif init_type == 'xavier':
74
+ init.xavier_normal_(m.weight.data, gain=gain)
75
+ elif init_type == 'kaiming':
76
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') # good for relu
77
+ elif init_type == 'orthogonal':
78
+ init.orthogonal_(m.weight.data, gain=gain)
79
+ else:
80
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
81
+
82
+ if hasattr(m, 'bias') and m.bias is not None:
83
+ init.constant_(m.bias.data, 0.0)
84
+ elif classname.find('BatchNorm2d') != -1:
85
+ init.normal_(m.weight.data, 1.0, gain)
86
+ init.constant_(m.bias.data, 0.0)
87
+
88
+ print("EAST <==> Prepare <==> Init Network'{}' <==> Done".format(cfg.init_type))
89
+
90
+
91
+ def Loading_checkpoint(model, optimizer, scheduler, filename='checkpoint.pth.tar'):
92
+ """[summary]
93
+ [description]
94
+ Arguments:
95
+ state {[type]} -- [description] a dict describe some params
96
+ Keyword Arguments:
97
+ filename {str} -- [description] (default: {'checkpoint.pth.tar'})
98
+ """
99
+ weightpath = os.path.abspath(cfg.checkpoint)
100
+ print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(weightpath))
101
+ checkpoint = torch.load(weightpath)
102
+ start_epoch = checkpoint['epoch'] + 1
103
+ model.load_state_dict(checkpoint['state_dict'])
104
+ optimizer.load_state_dict(checkpoint['optimizer'])
105
+ scheduler.load_state_dict(checkpoint['scheduler'])
106
+ print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(weightpath))
107
+
108
+ return start_epoch
109
+
110
+
111
+ def save_checkpoint(epoch, model, optimizer, scheduler, filename='checkpoint.pth.tar'):
112
+ """[summary]
113
+ [description]
114
+ Arguments:
115
+ state {[type]} -- [description] a dict describe some params
116
+ Keyword Arguments:
117
+ filename {str} -- [description] (default: {'checkpoint.pth.tar'})
118
+ """
119
+ print('EAST <==> Save weight - epoch {} <==> Begin'.format(epoch))
120
+ state = {
121
+ 'epoch': epoch,
122
+ 'state_dict': model.state_dict(),
123
+ 'optimizer': optimizer.state_dict(),
124
+ 'scheduler': scheduler.state_dict()
125
+ }
126
+ weight_dir = cfg.save_model_path
127
+ if not os.path.exists(weight_dir):
128
+ os.mkdir(weight_dir)
129
+ filename = 'epoch_' + str(epoch) + '_checkpoint.pth.tar'
130
+ file_path = os.path.join(weight_dir, filename)
131
+ torch.save(state, file_path)
132
+ print('EAST <==> Save weight - epoch {} <==> Done'.format(epoch))
133
+
134
+
135
+ class Regularization(torch.nn.Module):
136
+ def __init__(self, model, weight_decay, p=2):
137
+ super(Regularization, self).__init__()
138
+ if weight_decay < 0:
139
+ print("param weight_decay can not <0")
140
+ exit(0)
141
+ self.model = model
142
+ self.weight_decay = weight_decay
143
+ self.p = p
144
+ self.weight_list = self.get_weight(model)
145
+ # self.weight_info(self.weight_list)
146
+
147
+ def to(self, device):
148
+ self.device = device
149
+ super().to(device)
150
+ return self
151
+
152
+ def forward(self, model):
153
+ self.weight_list = self.get_weight(model) # 获得最新的权重
154
+ reg_loss = self.regularization_loss(self.weight_list, self.weight_decay, p=self.p)
155
+ return reg_loss
156
+
157
+ def get_weight(self, model):
158
+ weight_list = []
159
+ for name, param in model.named_parameters():
160
+ if 'weight' in name:
161
+ weight = (name, param)
162
+ weight_list.append(weight)
163
+ return weight_list
164
+
165
+ def regularization_loss(self, weight_list, weight_decay, p=2):
166
+ reg_loss = 0
167
+ for name, w in weight_list:
168
+ l2_reg = torch.norm(w, p=p)
169
+ reg_loss = reg_loss + l2_reg
170
+
171
+ reg_loss = weight_decay * reg_loss
172
+ return reg_loss
173
+
174
+ def weight_info(self, weight_list):
175
+ print("---------------regularization weight---------------")
176
+ for name, w in weight_list:
177
+ print(name)
178
+ print("---------------------------------------------------")
179
+
180
+
181
+ def resize_image(im, max_side_len=2400):
182
+ '''
183
+ resize image to a size multiple of 32 which is required by the network
184
+ :param im: the resized image
185
+ :param max_side_len: limit of max image size to avoid out of memory in gpu
186
+ :return: the resized image and the resize ratio
187
+ '''
188
+ h, w, _ = im.shape
189
+
190
+ resize_w = w
191
+ resize_h = h
192
+
193
+ # limit the max side
194
+ """
195
+ if max(resize_h, resize_w) > max_side_len:
196
+ ratio = float(max_side_len) / resize_h if resize_h > resize_w else float(max_side_len) / resize_w
197
+ else:
198
+ ratio = 1.
199
+
200
+ resize_h = int(resize_h * ratio)
201
+ resize_w = int(resize_w * ratio)
202
+ """
203
+
204
+ resize_h = resize_h if resize_h % 32 == 0 else (resize_h // 32 - 1) * 32
205
+ resize_w = resize_w if resize_w % 32 == 0 else (resize_w // 32 - 1) * 32
206
+ #resize_h, resize_w = 512, 512
207
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
208
+
209
+ ratio_h = resize_h / float(h)
210
+ ratio_w = resize_w / float(w)
211
+
212
+ return im, (ratio_h, ratio_w)
213
+
214
+
215
+ def detect(score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_thres=0.2):
216
+ '''
217
+ restore text boxes from score map and geo map
218
+ :param score_map:
219
+ :param geo_map:
220
+ :param timer:
221
+ :param score_map_thresh: threshhold for score map
222
+ :param box_thresh: threshhold for boxes
223
+ :param nms_thres: threshold for nms
224
+ :return:
225
+ '''
226
+
227
+ # score_map 和 geo_map 的维数进行调整
228
+ if len(score_map.shape) == 4:
229
+ score_map = score_map[0, :, :, 0]
230
+ geo_map = geo_map[0, :, :, :]
231
+ # filter the score map
232
+ xy_text = np.argwhere(score_map > score_map_thresh)
233
+ # sort the text boxes via the y axis
234
+ xy_text = xy_text[np.argsort(xy_text[:, 0])]
235
+ # restore
236
+ start = time.time()
237
+ text_box_restored = preprossing.restore_rectangle(xy_text[:, ::-1] * 4,
238
+ geo_map[xy_text[:, 0], xy_text[:, 1], :]) # N*4*2
239
+ # print('{} text boxes before nms'.format(text_box_restored.shape[0]))
240
+ boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
241
+ boxes[:, :8] = text_box_restored.reshape((-1, 8))
242
+ boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
243
+ timer['restore'] = time.time() - start
244
+ # nms part
245
+ start = time.time()
246
+ boxes = locality_aware_nms.nms_locality(boxes.astype(np.float64), nms_thres)
247
+ timer['nms'] = time.time() - start
248
+ # print(timer['nms'])
249
+ if boxes.shape[0] == 0:
250
+ return None, timer
251
+
252
+ # here we filter some low score boxes by the average score map, this is different from the orginal paper
253
+ for i, box in enumerate(boxes):
254
+ mask = np.zeros_like(score_map, dtype=np.uint8)
255
+ cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
256
+ boxes[i, 8] = cv2.mean(score_map, mask)[0]
257
+ boxes = boxes[boxes[:, 8] > box_thresh]
258
+ return boxes, timer
259
+
260
+
261
+ def sort_poly(p):
262
+ min_axis = np.argmin(np.sum(p, axis=1))
263
+ p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
264
+ if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
265
+ return p
266
+ else:
267
+ return p[[0, 3, 2, 1]]
268
+
269
+
270
+ def mean_image_subtraction(images, means=cfg.means):
271
+ '''
272
+ image normalization
273
+ :param images: bs * w * h * channel
274
+ :param means:
275
+ :return:
276
+ '''
277
+ num_channels = images.data.shape[1]
278
+ if len(means) != num_channels:
279
+ raise ValueError('len(means) must match the number of channels')
280
+ for i in range(num_channels):
281
+ images.data[:, i, :, :] -= means[i]
282
+
283
+ return images
IndicPhotoOCR/ocr.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import torch
4
+ from PIL import Image
5
+ import cv2
6
+ import numpy as np
7
+
8
+
9
+ from IndicPhotoOCR.detection.east_detector import EASTdetector
10
+ from IndicPhotoOCR.script_identification.CLIP_identifier import CLIPidentifier
11
+ from IndicPhotoOCR.recognition.parseq_recogniser import PARseqrecogniser
12
+ import IndicPhotoOCR.detection.east_config as cfg
13
+
14
+
15
+ class OCR:
16
+ def __init__(self, device='cuda:0', verbose=False):
17
+ # self.detect_model_checkpoint = detect_model_checkpoint
18
+ self.device = device
19
+ self.verbose = verbose
20
+ # self.image_path = image_path
21
+ self.detector = EASTdetector()
22
+ self.recogniser = PARseqrecogniser()
23
+ self.identifier = CLIPidentifier()
24
+
25
+ def detect(self, image_path, detect_model_checkpoint=cfg.checkpoint):
26
+ """Run the detection model to get bounding boxes of text areas."""
27
+
28
+ if self.verbose:
29
+ print("Running text detection...")
30
+ detections = self.detector.detect(image_path, detect_model_checkpoint, self.device)
31
+ # print(detections)
32
+ return detections['detections']
33
+
34
+ def visualize_detection(self, image_path, detections, save_path=None, show=False):
35
+ # Default save path if none is provided
36
+ default_save_path = "test.png"
37
+ path_to_save = save_path if save_path is not None else default_save_path
38
+
39
+ # Get the directory part of the path
40
+ directory = os.path.dirname(path_to_save)
41
+
42
+ # Check if the directory exists, and create it if it doesn’t
43
+ if directory and not os.path.exists(directory):
44
+ os.makedirs(directory)
45
+ print(f"Created directory: {directory}")
46
+
47
+ # Read the image and draw bounding boxes
48
+ image = cv2.imread(image_path)
49
+ for box in detections:
50
+ # Convert list of points to a numpy array with int type
51
+ points = np.array(box, np.int32)
52
+ points = points.reshape((-1, 1, 2)) # Reshape for cv2.polylines
53
+ # Draw the polygon
54
+ cv2.polylines(image, [points], isClosed=True, color=(0, 255, 0), thickness=3)
55
+
56
+ # Show the image if 'show' is True
57
+ if show:
58
+ plt.figure(figsize=(10, 10))
59
+ plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
60
+ plt.axis("off")
61
+ plt.show()
62
+
63
+ # Save the annotated image
64
+ cv2.imwrite(path_to_save, image)
65
+ print(f"Image saved at: {path_to_save}")
66
+
67
+ def crop_and_identify_script(self, image, bbox):
68
+ """
69
+ Crop a text area from the image and identify its script language.
70
+
71
+ Args:
72
+ image (PIL.Image): The full image.
73
+ bbox (list): List of four corner points, each a [x, y] pair.
74
+
75
+ Returns:
76
+ str: Identified script language.
77
+ """
78
+ # Extract x and y coordinates from the four corner points
79
+ x_coords = [point[0] for point in bbox]
80
+ y_coords = [point[1] for point in bbox]
81
+
82
+ # Get the bounding box coordinates (min and max)
83
+ x_min, y_min = min(x_coords), min(y_coords)
84
+ x_max, y_max = max(x_coords), max(y_coords)
85
+
86
+ # Crop the image based on the bounding box
87
+ cropped_image = image.crop((x_min, y_min, x_max, y_max))
88
+ root_image_dir = "IndicPhotoOCR/script_identification"
89
+ os.makedirs(f"{root_image_dir}/images", exist_ok=True)
90
+ # Temporarily save the cropped image to pass to the script model
91
+ cropped_path = f'{root_image_dir}/images/temp_crop_{x_min}_{y_min}.jpg'
92
+ cropped_image.save(cropped_path)
93
+
94
+ # Predict script language, here we assume "hindi" as the model name
95
+ if self.verbose:
96
+ print("Identifying script for the cropped area...")
97
+ script_lang = self.identifier.identify(cropped_path, "hindi") # Use "hindi" as the model name
98
+ # print(script_lang)
99
+
100
+ # Clean up temporary file
101
+ # os.remove(cropped_path)
102
+
103
+ return script_lang, cropped_path
104
+
105
+ def recognise(self, cropped_image_path, script_lang):
106
+ """Recognize text in a cropped image area using the identified script."""
107
+ if self.verbose:
108
+ print("Recognizing text in detected area...")
109
+ recognized_text = self.recogniser.recognise(script_lang, cropped_image_path, script_lang, self.verbose)
110
+ # print(recognized_text)
111
+ return recognized_text
112
+
113
+ def ocr(self, image_path):
114
+ """Process the image by detecting text areas, identifying script, and recognizing text."""
115
+ recognized_words = []
116
+ image = Image.open(image_path)
117
+
118
+ # Run detection
119
+ detections = self.detect(image_path)
120
+
121
+ # Process each detected text area
122
+ for bbox in detections:
123
+ # Crop and identify script language
124
+ script_lang, cropped_path = self.crop_and_identify_script(image, bbox)
125
+
126
+ # Check if the script language is valid
127
+ if script_lang:
128
+
129
+ # Recognize text
130
+ recognized_word = self.recognise(cropped_path, script_lang)
131
+ recognized_words.append(recognized_word)
132
+
133
+ if self.verbose:
134
+ print(f"Recognized word: {recognized_word}")
135
+
136
+ return recognized_words
137
+
138
+ if __name__ == '__main__':
139
+ # detect_model_checkpoint = 'bharatSTR/East/tmp/epoch_990_checkpoint.pth.tar'
140
+ sample_image_path = 'test_images/image_141.jpg'
141
+ cropped_image_path = 'test_images/cropped_image/image_141_0.jpg'
142
+
143
+ ocr = OCR(device="cpu", verbose=False)
144
+
145
+ # detections = ocr.detect(sample_image_path)
146
+ # print(detections)
147
+
148
+ # ocr.visualize_detection(sample_image_path, detections)
149
+
150
+ # recognition = ocr.recognise(cropped_image_path, "hindi")
151
+ # print(recognition)
152
+
153
+ recognised_words = ocr.ocr(sample_image_path)
154
+ print(recognised_words)
IndicPhotoOCR/recognition/__init__.py ADDED
File without changes
IndicPhotoOCR/recognition/parseq_recogniser.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ # import fire
3
+ import json
4
+ import numpy as np
5
+ import os
6
+ # import pandas as pd
7
+ import sys
8
+ import torch
9
+ import requests
10
+
11
+ from dataclasses import dataclass
12
+ from PIL import Image
13
+ from nltk import edit_distance
14
+ from torchvision import transforms as T
15
+ from typing import Optional, Callable, Sequence, Tuple
16
+ from tqdm import tqdm
17
+
18
+
19
+ from IndicPhotoOCR.utils.strhub.data.module import SceneTextDataModule
20
+ from IndicPhotoOCR.utils.strhub.models.utils import load_from_checkpoint
21
+
22
+
23
+ model_info = {
24
+ "assamese": {
25
+ "path": "models/assamese.ckpt",
26
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/assamese.ckpt",
27
+ },
28
+ "bengali": {
29
+ "path": "models/bengali.ckpt",
30
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/bengali.ckpt",
31
+ },
32
+ "hindi": {
33
+ "path": "models/hindi.ckpt",
34
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/hindi.ckpt",
35
+ },
36
+ "gujarati": {
37
+ "path": "models/gujarati.ckpt",
38
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/gujarati.ckpt",
39
+ },
40
+ "marathi": {
41
+ "path": "models/marathi.ckpt",
42
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/marathi.ckpt",
43
+ },
44
+ "odia": {
45
+ "path": "models/odia.ckpt",
46
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/odia.ckpt",
47
+ },
48
+ "punjabi": {
49
+ "path": "models/punjabi.ckpt",
50
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/punjabi.ckpt",
51
+ },
52
+ "tamil": {
53
+ "path": "models/tamil.ckpt",
54
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/tamil.ckpt",
55
+ },
56
+ "telugu": {
57
+ "path": "models/telugu.ckpt",
58
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/telugu.ckpt",
59
+ }
60
+ }
61
+
62
+ class PARseqrecogniser:
63
+ def __init__(self):
64
+ pass
65
+
66
+ def get_transform(self, img_size: Tuple[int], augment: bool = False, rotation: int = 0):
67
+ transforms = []
68
+ if augment:
69
+ from .augment import rand_augment_transform
70
+ transforms.append(rand_augment_transform())
71
+ if rotation:
72
+ transforms.append(lambda img: img.rotate(rotation, expand=True))
73
+ transforms.extend([
74
+ T.Resize(img_size, T.InterpolationMode.BICUBIC),
75
+ T.ToTensor(),
76
+ T.Normalize(0.5, 0.5)
77
+ ])
78
+ return T.Compose(transforms)
79
+
80
+
81
+ def load_model(self, device, checkpoint):
82
+ model = load_from_checkpoint(checkpoint).eval().to(device)
83
+ return model
84
+
85
+ def get_model_output(self, device, model, image_path):
86
+ hp = model.hparams
87
+ transform = self.get_transform(hp.img_size, rotation=0)
88
+
89
+ image_name = image_path.split("/")[-1]
90
+ img = Image.open(image_path).convert('RGB')
91
+ img = transform(img)
92
+ logits = model(img.unsqueeze(0).to(device))
93
+ probs = logits.softmax(-1)
94
+ preds, probs = model.tokenizer.decode(probs)
95
+ text = model.charset_adapter(preds[0])
96
+ scores = probs[0].detach().cpu().numpy()
97
+
98
+ return text
99
+
100
+ # Ensure model file exists; download directly if not
101
+ def ensure_model(self, model_name):
102
+ model_path = model_info[model_name]["path"]
103
+ url = model_info[model_name]["url"]
104
+ root_model_dir = "IndicPhotoOCR/recognition/"
105
+ model_path = os.path.join(root_model_dir, model_path)
106
+
107
+ if not os.path.exists(model_path):
108
+ print(f"Model not found locally. Downloading {model_name} from {url}...")
109
+
110
+ # Start the download with a progress bar
111
+ response = requests.get(url, stream=True)
112
+ total_size = int(response.headers.get('content-length', 0))
113
+ os.makedirs(f"{root_model_dir}/models", exist_ok=True)
114
+
115
+ with open(model_path, "wb") as f, tqdm(
116
+ desc=model_name,
117
+ total=total_size,
118
+ unit='B',
119
+ unit_scale=True,
120
+ unit_divisor=1024,
121
+ ) as bar:
122
+ for data in response.iter_content(chunk_size=1024):
123
+ f.write(data)
124
+ bar.update(len(data))
125
+
126
+ print(f"Downloaded model for {model_name}.")
127
+
128
+ return model_path
129
+
130
+ def bstr(checkpoint, language, image_dir, save_dir):
131
+ """
132
+ Runs the OCR model to process images and save the output as a JSON file.
133
+
134
+ Args:
135
+ checkpoint (str): Path to the model checkpoint file.
136
+ language (str): Language code (e.g., 'hindi', 'english').
137
+ image_dir (str): Directory containing the images to process.
138
+ save_dir (str): Directory where the output JSON file will be saved.
139
+
140
+ Example usage:
141
+ python your_script.py --checkpoint /path/to/checkpoint.ckpt --language hindi --image_dir /path/to/images --save_dir /path/to/save
142
+ """
143
+ device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
144
+
145
+ if language != "english":
146
+ model = load_model(device, checkpoint)
147
+ else:
148
+ model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)
149
+
150
+ parseq_dict = {}
151
+ for image_path in tqdm(os.listdir(image_dir)):
152
+ assert os.path.exists(os.path.join(image_dir, image_path)) == True, f"{image_path}"
153
+ text = get_model_output(device, model, os.path.join(image_dir, image_path), language=f"{language}")
154
+
155
+ filename = image_path.split('/')[-1]
156
+ parseq_dict[filename] = text
157
+
158
+ os.makedirs(save_dir, exist_ok=True)
159
+ with open(f"{save_dir}/{language}_test.json", 'w') as json_file:
160
+ json.dump(parseq_dict, json_file, indent=4, ensure_ascii=False)
161
+
162
+
163
+ def bstr_onImage(checkpoint, language, image_path):
164
+ """
165
+ Runs the OCR model to process images and save the output as a JSON file.
166
+
167
+ Args:
168
+ checkpoint (str): Path to the model checkpoint file.
169
+ language (str): Language code (e.g., 'hindi', 'english').
170
+ image_dir (str): Directory containing the images to process.
171
+ save_dir (str): Directory where the output JSON file will be saved.
172
+
173
+ Example usage:
174
+ python your_script.py --checkpoint /path/to/checkpoint.ckpt --language hindi --image_dir /path/to/images --save_dir /path/to/save
175
+ """
176
+ device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
177
+
178
+ if language != "english":
179
+ model = load_model(device, checkpoint)
180
+ else:
181
+ model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)
182
+
183
+ # parseq_dict = {}
184
+ # for image_path in tqdm(os.listdir(image_dir)):
185
+ # assert os.path.exists(os.path.join(image_dir, image_path)) == True, f"{image_path}"
186
+ text = get_model_output(device, model, image_path, language=f"{language}")
187
+
188
+ return text
189
+
190
+
191
+ def recognise(self, checkpoint: str, image_path: str, language: str, verbose: bool) -> str:
192
+ """
193
+ Loads the desired model and returns the recognized word from the specified image.
194
+
195
+ Args:
196
+ checkpoint (str): Path to the model checkpoint file.
197
+ language (str): Language code (e.g., 'hindi', 'english').
198
+ image_path (str): Path to the image for which text recognition is needed.
199
+
200
+ Returns:
201
+ str: The recognized text from the image.
202
+ """
203
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
204
+
205
+ if language != "english":
206
+ model_path = self.ensure_model(checkpoint)
207
+ model = self.load_model(device, model_path)
208
+ else:
209
+ model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True, verbose=verbose).eval().to(device)
210
+
211
+ recognized_text = self.get_model_output(device, model, image_path)
212
+
213
+ return recognized_text
214
+ # if __name__ == '__main__':
215
+ # fire.Fire(main)
IndicPhotoOCR/script_identification/CLIP_identifier.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import clip
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ import os
7
+ import requests
8
+
9
+ # Model information dictionary containing model paths and language subcategories
10
+ model_info = {
11
+ "hindi": {
12
+ "path": "models/clip_finetuned_hindienglish_real.pth",
13
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglish_real.pth",
14
+ "subcategories": ["hindi", "english"]
15
+ },
16
+ "hinengasm": {
17
+ "path": "models/clip_finetuned_hindienglishassamese_real.pth",
18
+ "url": "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishassamese_real.pth",
19
+ "subcategories": ["hindi", "english", "assamese"]
20
+ },
21
+ "hinengben": {
22
+ "path": "models/clip_finetuned_hindienglishbengali_real.pth",
23
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishbengali_real.pth",
24
+ "subcategories": ["hindi", "english", "bengali"]
25
+ },
26
+ "hinengguj": {
27
+ "path": "models/clip_finetuned_hindienglishgujarati_real.pth",
28
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishgujarati_real.pth",
29
+ "subcategories": ["hindi", "english", "gujarati"]
30
+ },
31
+ "hinengkan": {
32
+ "path": "models/clip_finetuned_hindienglishkannada_real.pth",
33
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishkannada_real.pth",
34
+ "subcategories": ["hindi", "english", "kannada"]
35
+ },
36
+ "hinengmal": {
37
+ "path": "models/clip_finetuned_hindienglishmalayalam_real.pth",
38
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmalayalam_real.pth",
39
+ "subcategories": ["hindi", "english", "malayalam"]
40
+ },
41
+ "hinengmar": {
42
+ "path": "models/clip_finetuned_hindienglishmarathi_real.pth",
43
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmarathi_real.pth",
44
+ "subcategories": ["hindi", "english", "marathi"]
45
+ },
46
+ "hinengmei": {
47
+ "path": "models/clip_finetuned_hindienglishmeitei_real.pth",
48
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmeitei_real.pth",
49
+ "subcategories": ["hindi", "english", "meitei"]
50
+ },
51
+ "hinengodi": {
52
+ "path": "models/clip_finetuned_hindienglishodia_real.pth",
53
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishodia_real.pth",
54
+ "subcategories": ["hindi", "english", "odia"]
55
+ },
56
+ "hinengpun": {
57
+ "path": "models/clip_finetuned_hindienglishpunjabi_real.pth",
58
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishpunjabi_real.pth",
59
+ "subcategories": ["hindi", "english", "punjabi"]
60
+ },
61
+ "hinengtam": {
62
+ "path": "models/clip_finetuned_hindienglishtamil_real.pth",
63
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishtamil_real.pth",
64
+ "subcategories": ["hindi", "english", "tamil"]
65
+ },
66
+ "hinengtel": {
67
+ "path": "models/clip_finetuned_hindienglishtelugu_real.pth",
68
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishtelugu_real.pth",
69
+ "subcategories": ["hindi", "english", "telugu"]
70
+ },
71
+ "hinengurd": {
72
+ "path": "models/clip_finetuned_hindienglishurdu_real.pth",
73
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishurdu_real.pth",
74
+ "subcategories": ["hindi", "english", "urdu"]
75
+ },
76
+
77
+
78
+ }
79
+
80
+
81
+ # Set device to CUDA if available, otherwise use CPU
82
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83
+ clip_model, preprocess = clip.load("ViT-B/32", device=device)
84
+
85
+ class CLIPFineTuner(torch.nn.Module):
86
+ """
87
+ Fine-tuning class for the CLIP model to adapt to specific tasks.
88
+
89
+ Attributes:
90
+ model (torch.nn.Module): The CLIP model to be fine-tuned.
91
+ classifier (torch.nn.Linear): A linear classifier to map features to the desired number of classes.
92
+ """
93
+ def __init__(self, model, num_classes):
94
+ """
95
+ Initializes the fine-tuner with the CLIP model and classifier.
96
+
97
+ Args:
98
+ model (torch.nn.Module): The base CLIP model.
99
+ num_classes (int): The number of target classes for classification.
100
+ """
101
+ super(CLIPFineTuner, self).__init__()
102
+ self.model = model
103
+ self.classifier = torch.nn.Linear(model.visual.output_dim, num_classes)
104
+
105
+ def forward(self, x):
106
+ """
107
+ Forward pass for image classification.
108
+
109
+ Args:
110
+ x (torch.Tensor): Preprocessed input tensor for an image.
111
+
112
+ Returns:
113
+ torch.Tensor: Logits for each class.
114
+ """
115
+ with torch.no_grad():
116
+ features = self.model.encode_image(x).float() # Extract image features from CLIP model
117
+ return self.classifier(features) # Return class logits
118
+
119
+ class CLIPidentifier:
120
+ def __init__(self):
121
+ pass
122
+
123
+ # Ensure model file exists; download directly if not
124
+ def ensure_model(self, model_name):
125
+ model_path = model_info[model_name]["path"]
126
+ url = model_info[model_name]["url"]
127
+ root_model_dir = "IndicPhotoOCR/script_identification/"
128
+ model_path = os.path.join(root_model_dir, model_path)
129
+
130
+ if not os.path.exists(model_path):
131
+ print(f"Model not found locally. Downloading {model_name} from {url}...")
132
+ response = requests.get(url, stream=True)
133
+ os.makedirs(f"{root_model_dir}/models", exist_ok=True)
134
+ with open(f"{model_path}", "wb") as f:
135
+ f.write(response.content)
136
+ print(f"Downloaded model for {model_name}.")
137
+
138
+ return model_path
139
+
140
+ # Prediction function to verify and load the model
141
+ def identify(self, image_path, model_name):
142
+ """
143
+ Predicts the class of an input image using a fine-tuned CLIP model.
144
+
145
+ Args:
146
+ image_path (str): Path to the input image file.
147
+ model_name (str): Name of the model (e.g., hineng, hinengpun, hinengguj) as specified in `model_info`.
148
+
149
+ Returns:
150
+ dict: Contains either `predicted_class` if successful or `error` if an exception occurs.
151
+
152
+ Example usage:
153
+ result = predict("sample_image.jpg", "hinengguj")
154
+ print(result) # Output might be {'predicted_class': 'hindi'}
155
+ """
156
+ try:
157
+ # Validate model name and retrieve associated subcategories
158
+ if model_name not in model_info:
159
+ return {"error": "Invalid model name"}
160
+
161
+ # Ensure the model file is downloaded and accessible
162
+ model_path = self.ensure_model(model_name)
163
+
164
+
165
+ subcategories = model_info[model_name]["subcategories"]
166
+ num_classes = len(subcategories)
167
+
168
+ # Load the fine-tuned model with the specified number of classes
169
+ model_ft = CLIPFineTuner(clip_model, num_classes)
170
+ model_ft.load_state_dict(torch.load(model_path, map_location=device))
171
+ model_ft = model_ft.to(device)
172
+ model_ft.eval()
173
+
174
+ # Load and preprocess the image
175
+ image = Image.open(image_path).convert("RGB")
176
+ input_tensor = preprocess(image).unsqueeze(0).to(device)
177
+
178
+ # Run the model and get the prediction
179
+ outputs = model_ft(input_tensor)
180
+ _, predicted_idx = torch.max(outputs, 1)
181
+ predicted_class = subcategories[predicted_idx.item()]
182
+
183
+ return predicted_class
184
+
185
+ except Exception as e:
186
+ return {"error": str(e)}
187
+
188
+
189
+ # if __name__ == "__main__":
190
+ # import argparse
191
+
192
+ # # Argument parser for command line usage
193
+ # parser = argparse.ArgumentParser(description="Image classification using CLIP fine-tuned model")
194
+ # parser.add_argument("image_path", type=str, help="Path to the input image")
195
+ # parser.add_argument("model_name", type=str, choices=model_info.keys(), help="Name of the model (e.g., hineng, hinengpun, hinengguj)")
196
+
197
+ # args = parser.parse_args()
198
+
199
+ # # Execute prediction with command line inputs
200
+ # result = predict(args.image_path, args.model_name)
201
+ # print(result)
IndicPhotoOCR/script_identification/__init__.py ADDED
File without changes
IndicPhotoOCR/theme.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Iterable
3
+ import gradio as gr
4
+ from gradio.themes.base import Base
5
+ from gradio.themes.utils import colors, fonts, sizes
6
+ import time
7
+
8
+
9
+ class Seafoam(Base):
10
+ def __init__(
11
+ self,
12
+ *,
13
+ primary_hue: colors.Color | str = colors.emerald,
14
+ secondary_hue: colors.Color | str = colors.blue,
15
+ neutral_hue: colors.Color | str = colors.gray,
16
+ spacing_size: sizes.Size | str = sizes.spacing_md,
17
+ radius_size: sizes.Size | str = sizes.radius_md,
18
+ text_size: sizes.Size | str = sizes.text_lg,
19
+ font: fonts.Font
20
+ | str
21
+ | Iterable[fonts.Font | str] = (
22
+ fonts.GoogleFont("Quicksand"),
23
+ "ui-sans-serif",
24
+ "sans-serif",
25
+ ),
26
+ font_mono: fonts.Font
27
+ | str
28
+ | Iterable[fonts.Font | str] = (
29
+ fonts.GoogleFont("IBM Plex Mono"),
30
+ "ui-monospace",
31
+ "monospace",
32
+ ),
33
+ ):
34
+ super().__init__(
35
+ primary_hue=primary_hue,
36
+ secondary_hue=secondary_hue,
37
+ neutral_hue=neutral_hue,
38
+ spacing_size=spacing_size,
39
+ radius_size=radius_size,
40
+ text_size=text_size,
41
+ font=font,
42
+ font_mono=font_mono,
43
+ )
IndicPhotoOCR/utils/strhub/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # from data.module import SceneTextDataModule
2
+ # from model.utils import load_from_checkpoint
IndicPhotoOCR/utils/strhub/data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # from .module import SceneTextDataModule
IndicPhotoOCR/utils/strhub/data/aa_overrides.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Extends default ops to accept optional parameters."""
17
+ from functools import partial
18
+
19
+ from timm.data.auto_augment import _LEVEL_DENOM, LEVEL_TO_ARG, NAME_TO_OP, _randomly_negate, rotate
20
+
21
+
22
+ def rotate_expand(img, degrees, **kwargs):
23
+ """Rotate operation with expand=True to avoid cutting off the characters"""
24
+ kwargs['expand'] = True
25
+ return rotate(img, degrees, **kwargs)
26
+
27
+
28
+ def _level_to_arg(level, hparams, key, default):
29
+ magnitude = hparams.get(key, default)
30
+ level = (level / _LEVEL_DENOM) * magnitude
31
+ level = _randomly_negate(level)
32
+ return (level,)
33
+
34
+
35
+ def apply():
36
+ # Overrides
37
+ NAME_TO_OP.update({
38
+ 'Rotate': rotate_expand,
39
+ })
40
+ LEVEL_TO_ARG.update({
41
+ 'Rotate': partial(_level_to_arg, key='rotate_deg', default=30.0),
42
+ 'ShearX': partial(_level_to_arg, key='shear_x_pct', default=0.3),
43
+ 'ShearY': partial(_level_to_arg, key='shear_y_pct', default=0.3),
44
+ 'TranslateXRel': partial(_level_to_arg, key='translate_x_pct', default=0.45),
45
+ 'TranslateYRel': partial(_level_to_arg, key='translate_y_pct', default=0.45),
46
+ })
IndicPhotoOCR/utils/strhub/data/augment.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+
18
+ import imgaug.augmenters as iaa
19
+ import numpy as np
20
+ from PIL import Image, ImageFilter
21
+
22
+ from timm.data import auto_augment
23
+
24
+ from strhub.data import aa_overrides
25
+
26
+ aa_overrides.apply()
27
+
28
+ _OP_CACHE = {}
29
+
30
+
31
+ def _get_op(key, factory):
32
+ try:
33
+ op = _OP_CACHE[key]
34
+ except KeyError:
35
+ op = factory()
36
+ _OP_CACHE[key] = op
37
+ return op
38
+
39
+
40
+ def _get_param(level, img, max_dim_factor, min_level=1):
41
+ max_level = max(min_level, max_dim_factor * max(img.size))
42
+ return round(min(level, max_level))
43
+
44
+
45
+ def gaussian_blur(img, radius, **__):
46
+ radius = _get_param(radius, img, 0.02)
47
+ key = 'gaussian_blur_' + str(radius)
48
+ op = _get_op(key, lambda: ImageFilter.GaussianBlur(radius))
49
+ return img.filter(op)
50
+
51
+
52
+ def motion_blur(img, k, **__):
53
+ k = _get_param(k, img, 0.08, 3) | 1 # bin to odd values
54
+ key = 'motion_blur_' + str(k)
55
+ op = _get_op(key, lambda: iaa.MotionBlur(k))
56
+ return Image.fromarray(op(image=np.asarray(img)))
57
+
58
+
59
+ def gaussian_noise(img, scale, **_):
60
+ scale = _get_param(scale, img, 0.25) | 1 # bin to odd values
61
+ key = 'gaussian_noise_' + str(scale)
62
+ op = _get_op(key, lambda: iaa.AdditiveGaussianNoise(scale=scale))
63
+ return Image.fromarray(op(image=np.asarray(img)))
64
+
65
+
66
+ def poisson_noise(img, lam, **_):
67
+ lam = _get_param(lam, img, 0.2) | 1 # bin to odd values
68
+ key = 'poisson_noise_' + str(lam)
69
+ op = _get_op(key, lambda: iaa.AdditivePoissonNoise(lam))
70
+ return Image.fromarray(op(image=np.asarray(img)))
71
+
72
+
73
+ def _level_to_arg(level, _hparams, max):
74
+ level = max * level / auto_augment._LEVEL_DENOM
75
+ return (level,)
76
+
77
+
78
+ _RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy()
79
+ _RAND_TRANSFORMS.remove('SharpnessIncreasing') # remove, interferes with *blur ops
80
+ _RAND_TRANSFORMS.extend([
81
+ 'GaussianBlur',
82
+ # 'MotionBlur',
83
+ # 'GaussianNoise',
84
+ 'PoissonNoise',
85
+ ])
86
+ auto_augment.LEVEL_TO_ARG.update({
87
+ 'GaussianBlur': partial(_level_to_arg, max=4),
88
+ 'MotionBlur': partial(_level_to_arg, max=20),
89
+ 'GaussianNoise': partial(_level_to_arg, max=0.1 * 255),
90
+ 'PoissonNoise': partial(_level_to_arg, max=40),
91
+ })
92
+ auto_augment.NAME_TO_OP.update({
93
+ 'GaussianBlur': gaussian_blur,
94
+ 'MotionBlur': motion_blur,
95
+ 'GaussianNoise': gaussian_noise,
96
+ 'PoissonNoise': poisson_noise,
97
+ })
98
+
99
+
100
+ def rand_augment_transform(magnitude=5, num_layers=3):
101
+ # These are tuned for magnitude=5, which means that effective magnitudes are half of these values.
102
+ hparams = {
103
+ 'rotate_deg': 30,
104
+ 'shear_x_pct': 0.9,
105
+ 'shear_y_pct': 0.2,
106
+ 'translate_x_pct': 0.10,
107
+ 'translate_y_pct': 0.30,
108
+ }
109
+ ra_ops = auto_augment.rand_augment_ops(magnitude, hparams=hparams, transforms=_RAND_TRANSFORMS)
110
+ # Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice)
111
+ choice_weights = [1.0 / len(ra_ops) for _ in range(len(ra_ops))]
112
+ return auto_augment.RandAugment(ra_ops, num_layers, choice_weights)
IndicPhotoOCR/utils/strhub/data/dataset.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import glob
16
+ import io
17
+ import logging
18
+ import unicodedata
19
+ from pathlib import Path, PurePath
20
+ from typing import Callable, Optional, Union
21
+
22
+ import lmdb
23
+ from PIL import Image
24
+
25
+ from torch.utils.data import ConcatDataset, Dataset
26
+
27
+ from IndicPhotoOCR.utils.strhub.data.utils import CharsetAdapter
28
+
29
+ log = logging.getLogger(__name__)
30
+
31
+
32
+ def build_tree_dataset(root: Union[PurePath, str], *args, **kwargs):
33
+ try:
34
+ kwargs.pop('root') # prevent 'root' from being passed via kwargs
35
+ except KeyError:
36
+ pass
37
+ root = Path(root).absolute()
38
+ log.info(f'dataset root:\t{root}')
39
+ datasets = []
40
+ for mdb in glob.glob(str(root / '**/data.mdb'), recursive=True):
41
+ mdb = Path(mdb)
42
+ ds_name = str(mdb.parent.relative_to(root))
43
+ ds_root = str(mdb.parent.absolute())
44
+ dataset = LmdbDataset(ds_root, *args, **kwargs)
45
+ log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}')
46
+ datasets.append(dataset)
47
+ return ConcatDataset(datasets)
48
+
49
+
50
+ class LmdbDataset(Dataset):
51
+ """Dataset interface to an LMDB database.
52
+
53
+ It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned
54
+ as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset.
55
+ Labels are transformed according to the charset.
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ root: str,
61
+ charset: str,
62
+ max_label_len: int,
63
+ min_image_dim: int = 0,
64
+ remove_whitespace: bool = True,
65
+ normalize_unicode: bool = True,
66
+ unlabelled: bool = False,
67
+ transform: Optional[Callable] = None,
68
+ ):
69
+ self._env = None
70
+ self.root = root
71
+ self.unlabelled = unlabelled
72
+ self.transform = transform
73
+ self.labels = []
74
+ self.filtered_index_list = []
75
+ self.num_samples = self._preprocess_labels(
76
+ charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim
77
+ )
78
+
79
+ def __del__(self):
80
+ if self._env is not None:
81
+ self._env.close()
82
+ self._env = None
83
+
84
+ def _create_env(self):
85
+ return lmdb.open(
86
+ self.root, max_readers=1, readonly=True, create=False, readahead=False, meminit=False, lock=False
87
+ )
88
+
89
+ @property
90
+ def env(self):
91
+ if self._env is None:
92
+ self._env = self._create_env()
93
+ return self._env
94
+
95
+ def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim):
96
+ charset_adapter = CharsetAdapter(charset)
97
+ with self._create_env() as env, env.begin() as txn:
98
+ num_samples = int(txn.get('num-samples'.encode()))
99
+ if self.unlabelled:
100
+ return num_samples
101
+ for index in range(num_samples):
102
+ index += 1 # lmdb starts with 1
103
+ label_key = f'label-{index:09d}'.encode()
104
+ label = txn.get(label_key).decode()
105
+ # Normally, whitespace is removed from the labels.
106
+ if remove_whitespace:
107
+ label = ''.join(label.split())
108
+ # Normalize unicode composites (if any) and convert to compatible ASCII characters
109
+ if normalize_unicode:
110
+ label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode()
111
+ # Filter by length before removing unsupported characters. The original label might be too long.
112
+ if len(label) > max_label_len:
113
+ continue
114
+ label = charset_adapter(label)
115
+ # We filter out samples which don't contain any supported characters
116
+ if not label:
117
+ continue
118
+ # Filter images that are too small.
119
+ if min_image_dim > 0:
120
+ img_key = f'image-{index:09d}'.encode()
121
+ buf = io.BytesIO(txn.get(img_key))
122
+ w, h = Image.open(buf).size
123
+ if w < self.min_image_dim or h < self.min_image_dim:
124
+ continue
125
+ self.labels.append(label)
126
+ self.filtered_index_list.append(index)
127
+ return len(self.labels)
128
+
129
+ def __len__(self):
130
+ return self.num_samples
131
+
132
+ def __getitem__(self, index):
133
+ if self.unlabelled:
134
+ label = index
135
+ else:
136
+ label = self.labels[index]
137
+ index = self.filtered_index_list[index]
138
+
139
+ img_key = f'image-{index:09d}'.encode()
140
+ with self.env.begin() as txn:
141
+ imgbuf = txn.get(img_key)
142
+ buf = io.BytesIO(imgbuf)
143
+ img = Image.open(buf).convert('RGB')
144
+
145
+ if self.transform is not None:
146
+ img = self.transform(img)
147
+
148
+ return img, label
IndicPhotoOCR/utils/strhub/data/module.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from pathlib import PurePath
17
+ from typing import Callable, Optional, Sequence
18
+
19
+ from torch.utils.data import DataLoader
20
+ from torchvision import transforms as T
21
+
22
+ import pytorch_lightning as pl
23
+
24
+ from IndicPhotoOCR.utils.strhub.data.dataset import LmdbDataset, build_tree_dataset
25
+
26
+
27
+ class SceneTextDataModule(pl.LightningDataModule):
28
+ TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80')
29
+ TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80')
30
+ TEST_NEW = ('ArT', 'COCOv1.4', 'Uber')
31
+ TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW))
32
+
33
+ def __init__(
34
+ self,
35
+ root_dir: str,
36
+ train_dir: str,
37
+ img_size: Sequence[int],
38
+ max_label_length: int,
39
+ charset_train: str,
40
+ charset_test: str,
41
+ batch_size: int,
42
+ num_workers: int,
43
+ augment: bool,
44
+ remove_whitespace: bool = True,
45
+ normalize_unicode: bool = True,
46
+ min_image_dim: int = 0,
47
+ rotation: int = 0,
48
+ collate_fn: Optional[Callable] = None,
49
+ ):
50
+ super().__init__()
51
+ self.root_dir = root_dir
52
+ self.train_dir = train_dir
53
+ self.img_size = tuple(img_size)
54
+ self.max_label_length = max_label_length
55
+ self.charset_train = charset_train
56
+ self.charset_test = charset_test
57
+ self.batch_size = batch_size
58
+ self.num_workers = num_workers
59
+ self.augment = augment
60
+ self.remove_whitespace = remove_whitespace
61
+ self.normalize_unicode = normalize_unicode
62
+ self.min_image_dim = min_image_dim
63
+ self.rotation = rotation
64
+ self.collate_fn = collate_fn
65
+ self._train_dataset = None
66
+ self._val_dataset = None
67
+
68
+ @staticmethod
69
+ def get_transform(img_size: tuple[int], augment: bool = False, rotation: int = 0):
70
+ transforms = []
71
+ if augment:
72
+ from .augment import rand_augment_transform
73
+
74
+ transforms.append(rand_augment_transform())
75
+ if rotation:
76
+ transforms.append(lambda img: img.rotate(rotation, expand=True))
77
+ transforms.extend([
78
+ T.Resize(img_size, T.InterpolationMode.BICUBIC),
79
+ T.ToTensor(),
80
+ T.Normalize(0.5, 0.5),
81
+ ])
82
+ return T.Compose(transforms)
83
+
84
+ @property
85
+ def train_dataset(self):
86
+ if self._train_dataset is None:
87
+ transform = self.get_transform(self.img_size, self.augment)
88
+ root = PurePath(self.root_dir, 'train', self.train_dir)
89
+ self._train_dataset = build_tree_dataset(
90
+ root,
91
+ self.charset_train,
92
+ self.max_label_length,
93
+ self.min_image_dim,
94
+ self.remove_whitespace,
95
+ self.normalize_unicode,
96
+ transform=transform,
97
+ )
98
+ return self._train_dataset
99
+
100
+ @property
101
+ def val_dataset(self):
102
+ if self._val_dataset is None:
103
+ transform = self.get_transform(self.img_size)
104
+ root = PurePath(self.root_dir, 'val')
105
+ self._val_dataset = build_tree_dataset(
106
+ root,
107
+ self.charset_test,
108
+ self.max_label_length,
109
+ self.min_image_dim,
110
+ self.remove_whitespace,
111
+ self.normalize_unicode,
112
+ transform=transform,
113
+ )
114
+ return self._val_dataset
115
+
116
+ def train_dataloader(self):
117
+ return DataLoader(
118
+ self.train_dataset,
119
+ batch_size=self.batch_size,
120
+ shuffle=True,
121
+ num_workers=self.num_workers,
122
+ persistent_workers=self.num_workers > 0,
123
+ pin_memory=True,
124
+ collate_fn=self.collate_fn,
125
+ )
126
+
127
+ def val_dataloader(self):
128
+ return DataLoader(
129
+ self.val_dataset,
130
+ batch_size=self.batch_size,
131
+ num_workers=self.num_workers,
132
+ persistent_workers=self.num_workers > 0,
133
+ pin_memory=True,
134
+ collate_fn=self.collate_fn,
135
+ )
136
+
137
+ def test_dataloaders(self, subset):
138
+ transform = self.get_transform(self.img_size, rotation=self.rotation)
139
+ root = PurePath(self.root_dir, 'test')
140
+ datasets = {
141
+ s: LmdbDataset(
142
+ str(root / s),
143
+ self.charset_test,
144
+ self.max_label_length,
145
+ self.min_image_dim,
146
+ self.remove_whitespace,
147
+ self.normalize_unicode,
148
+ transform=transform,
149
+ )
150
+ for s in subset
151
+ }
152
+ return {
153
+ k: DataLoader(
154
+ v, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=self.collate_fn
155
+ )
156
+ for k, v in datasets.items()
157
+ }
IndicPhotoOCR/utils/strhub/data/utils.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ from abc import ABC, abstractmethod
18
+ from itertools import groupby
19
+ from typing import Optional
20
+
21
+ import torch
22
+ from torch import Tensor
23
+ from torch.nn.utils.rnn import pad_sequence
24
+
25
+
26
+ class CharsetAdapter:
27
+ """Transforms labels according to the target charset."""
28
+
29
+ def __init__(self, target_charset) -> None:
30
+ super().__init__()
31
+ self.lowercase_only = target_charset == target_charset.lower()
32
+ self.uppercase_only = target_charset == target_charset.upper()
33
+ self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')
34
+
35
+ def __call__(self, label):
36
+ if self.lowercase_only:
37
+ label = label.lower()
38
+ elif self.uppercase_only:
39
+ label = label.upper()
40
+ # Remove unsupported characters
41
+ label = self.unsupported.sub('', label)
42
+ return label
43
+
44
+
45
+ class BaseTokenizer(ABC):
46
+
47
+ def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
48
+ self._itos = specials_first + tuple(charset) + specials_last
49
+ self._stoi = {s: i for i, s in enumerate(self._itos)}
50
+
51
+ def __len__(self):
52
+ return len(self._itos)
53
+
54
+ def _tok2ids(self, tokens: str) -> list[int]:
55
+ return [self._stoi[s] for s in tokens]
56
+
57
+ def _ids2tok(self, token_ids: list[int], join: bool = True) -> str:
58
+ tokens = [self._itos[i] for i in token_ids]
59
+ return ''.join(tokens) if join else tokens
60
+
61
+ @abstractmethod
62
+ def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
63
+ """Encode a batch of labels to a representation suitable for the model.
64
+
65
+ Args:
66
+ labels: List of labels. Each can be of arbitrary length.
67
+ device: Create tensor on this device.
68
+
69
+ Returns:
70
+ Batched tensor representation padded to the max label length. Shape: N, L
71
+ """
72
+ raise NotImplementedError
73
+
74
+ @abstractmethod
75
+ def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
76
+ """Internal method which performs the necessary filtering prior to decoding."""
77
+ raise NotImplementedError
78
+
79
+ def decode(self, token_dists: Tensor, raw: bool = False) -> tuple[list[str], list[Tensor]]:
80
+ """Decode a batch of token distributions.
81
+
82
+ Args:
83
+ token_dists: softmax probabilities over the token distribution. Shape: N, L, C
84
+ raw: return unprocessed labels (will return list of list of strings)
85
+
86
+ Returns:
87
+ list of string labels (arbitrary length) and
88
+ their corresponding sequence probabilities as a list of Tensors
89
+ """
90
+ batch_tokens = []
91
+ batch_probs = []
92
+ for dist in token_dists:
93
+ probs, ids = dist.max(-1) # greedy selection
94
+ if not raw:
95
+ probs, ids = self._filter(probs, ids)
96
+ tokens = self._ids2tok(ids, not raw)
97
+ batch_tokens.append(tokens)
98
+ batch_probs.append(probs)
99
+ return batch_tokens, batch_probs
100
+
101
+
102
+ class Tokenizer(BaseTokenizer):
103
+ BOS = '[B]'
104
+ EOS = '[E]'
105
+ PAD = '[P]'
106
+
107
+ def __init__(self, charset: str) -> None:
108
+ specials_first = (self.EOS,)
109
+ specials_last = (self.BOS, self.PAD)
110
+ super().__init__(charset, specials_first, specials_last)
111
+ self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last]
112
+
113
+ def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
114
+ batch = [
115
+ torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device)
116
+ for y in labels
117
+ ]
118
+ return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)
119
+
120
+ def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
121
+ ids = ids.tolist()
122
+ try:
123
+ eos_idx = ids.index(self.eos_id)
124
+ except ValueError:
125
+ eos_idx = len(ids) # Nothing to truncate.
126
+ # Truncate after EOS
127
+ ids = ids[:eos_idx]
128
+ probs = probs[: eos_idx + 1] # but include prob. for EOS (if it exists)
129
+ return probs, ids
130
+
131
+
132
+ class CTCTokenizer(BaseTokenizer):
133
+ BLANK = '[B]'
134
+
135
+ def __init__(self, charset: str) -> None:
136
+ # BLANK uses index == 0 by default
137
+ super().__init__(charset, specials_first=(self.BLANK,))
138
+ self.blank_id = self._stoi[self.BLANK]
139
+
140
+ def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
141
+ # We use a padded representation since we don't want to use CUDNN's CTC implementation
142
+ batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels]
143
+ return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)
144
+
145
+ def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
146
+ # Best path decoding:
147
+ ids = list(zip(*groupby(ids.tolist())))[0] # Remove duplicate tokens
148
+ ids = [x for x in ids if x != self.blank_id] # Remove BLANKs
149
+ # `probs` is just pass-through since all positions are considered part of the path
150
+ return probs, ids
IndicPhotoOCR/utils/strhub/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # from .utils import load_from_checkpoint
IndicPhotoOCR/utils/strhub/models/abinet/LICENSE ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ABINet for non-commercial purposes
2
+
3
+ Copyright (c) 2021, USTC
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ 1. Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
IndicPhotoOCR/utils/strhub/models/abinet/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""
2
+ Fang, Shancheng, Hongtao, Xie, Yuxin, Wang, Zhendong, Mao, and Yongdong, Zhang.
3
+ "Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition." .
4
+ In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 7098-7107).2021.
5
+
6
+ https://arxiv.org/abs/2103.06495
7
+
8
+ All source files, except `system.py`, are based on the implementation listed below,
9
+ and hence are released under the license of the original.
10
+
11
+ Source: https://github.com/FangShancheng/ABINet
12
+ License: 2-clause BSD License (see included LICENSE file)
13
+ """
IndicPhotoOCR/utils/strhub/models/abinet/attention.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .transformer import PositionalEncoding
5
+
6
+
7
+ class Attention(nn.Module):
8
+ def __init__(self, in_channels=512, max_length=25, n_feature=256):
9
+ super().__init__()
10
+ self.max_length = max_length
11
+
12
+ self.f0_embedding = nn.Embedding(max_length, in_channels)
13
+ self.w0 = nn.Linear(max_length, n_feature)
14
+ self.wv = nn.Linear(in_channels, in_channels)
15
+ self.we = nn.Linear(in_channels, max_length)
16
+
17
+ self.active = nn.Tanh()
18
+ self.softmax = nn.Softmax(dim=2)
19
+
20
+ def forward(self, enc_output):
21
+ enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2)
22
+ reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device)
23
+ reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S)
24
+ reading_order_embed = self.f0_embedding(reading_order) # b,25,512
25
+
26
+ t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256
27
+ t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512
28
+
29
+ attn = self.we(t) # b,256,25
30
+ attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256
31
+ g_output = torch.bmm(attn, enc_output) # b,25,512
32
+ return g_output, attn.view(*attn.shape[:2], 8, 32)
33
+
34
+
35
+ def encoder_layer(in_c, out_c, k=3, s=2, p=1):
36
+ return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p),
37
+ nn.BatchNorm2d(out_c),
38
+ nn.ReLU(True))
39
+
40
+
41
+ def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None):
42
+ align_corners = None if mode == 'nearest' else True
43
+ return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor,
44
+ mode=mode, align_corners=align_corners),
45
+ nn.Conv2d(in_c, out_c, k, s, p),
46
+ nn.BatchNorm2d(out_c),
47
+ nn.ReLU(True))
48
+
49
+
50
+ class PositionAttention(nn.Module):
51
+ def __init__(self, max_length, in_channels=512, num_channels=64,
52
+ h=8, w=32, mode='nearest', **kwargs):
53
+ super().__init__()
54
+ self.max_length = max_length
55
+ self.k_encoder = nn.Sequential(
56
+ encoder_layer(in_channels, num_channels, s=(1, 2)),
57
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
58
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
59
+ encoder_layer(num_channels, num_channels, s=(2, 2))
60
+ )
61
+ self.k_decoder = nn.Sequential(
62
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
63
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
64
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
65
+ decoder_layer(num_channels, in_channels, size=(h, w), mode=mode)
66
+ )
67
+
68
+ self.pos_encoder = PositionalEncoding(in_channels, dropout=0., max_len=max_length)
69
+ self.project = nn.Linear(in_channels, in_channels)
70
+
71
+ def forward(self, x):
72
+ N, E, H, W = x.size()
73
+ k, v = x, x # (N, E, H, W)
74
+
75
+ # calculate key vector
76
+ features = []
77
+ for i in range(0, len(self.k_encoder)):
78
+ k = self.k_encoder[i](k)
79
+ features.append(k)
80
+ for i in range(0, len(self.k_decoder) - 1):
81
+ k = self.k_decoder[i](k)
82
+ k = k + features[len(self.k_decoder) - 2 - i]
83
+ k = self.k_decoder[-1](k)
84
+
85
+ # calculate query vector
86
+ # TODO q=f(q,k)
87
+ zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E)
88
+ q = self.pos_encoder(zeros) # (T, N, E)
89
+ q = q.permute(1, 0, 2) # (N, T, E)
90
+ q = self.project(q) # (N, T, E)
91
+
92
+ # calculate attention
93
+ attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W))
94
+ attn_scores = attn_scores / (E ** 0.5)
95
+ attn_scores = torch.softmax(attn_scores, dim=-1)
96
+
97
+ v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E)
98
+ attn_vecs = torch.bmm(attn_scores, v) # (N, T, E)
99
+
100
+ return attn_vecs, attn_scores.view(N, -1, H, W)
IndicPhotoOCR/utils/strhub/models/abinet/backbone.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from torch.nn import TransformerEncoderLayer, TransformerEncoder
3
+
4
+ from .resnet import resnet45
5
+ from .transformer import PositionalEncoding
6
+
7
+
8
+ class ResTranformer(nn.Module):
9
+ def __init__(self, d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', backbone_ln=2):
10
+ super().__init__()
11
+ self.resnet = resnet45()
12
+ self.pos_encoder = PositionalEncoding(d_model, max_len=8 * 32)
13
+ encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead,
14
+ dim_feedforward=d_inner, dropout=dropout, activation=activation)
15
+ self.transformer = TransformerEncoder(encoder_layer, backbone_ln)
16
+
17
+ def forward(self, images):
18
+ feature = self.resnet(images)
19
+ n, c, h, w = feature.shape
20
+ feature = feature.view(n, c, -1).permute(2, 0, 1)
21
+ feature = self.pos_encoder(feature)
22
+ feature = self.transformer(feature)
23
+ feature = feature.permute(1, 2, 0).view(n, c, h, w)
24
+ return feature
IndicPhotoOCR/utils/strhub/models/abinet/model.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class Model(nn.Module):
6
+
7
+ def __init__(self, dataset_max_length: int, null_label: int):
8
+ super().__init__()
9
+ self.max_length = dataset_max_length + 1 # additional stop token
10
+ self.null_label = null_label
11
+
12
+ def _get_length(self, logit, dim=-1):
13
+ """ Greed decoder to obtain length from logit"""
14
+ out = (logit.argmax(dim=-1) == self.null_label)
15
+ abn = out.any(dim)
16
+ out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
17
+ out = out + 1 # additional end token
18
+ out = torch.where(abn, out, out.new_tensor(logit.shape[1], device=out.device))
19
+ return out
20
+
21
+ @staticmethod
22
+ def _get_padding_mask(length, max_length):
23
+ length = length.unsqueeze(-1)
24
+ grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
25
+ return grid >= length
26
+
27
+ @staticmethod
28
+ def _get_location_mask(sz, device=None):
29
+ mask = torch.eye(sz, device=device)
30
+ mask = mask.float().masked_fill(mask == 1, float('-inf'))
31
+ return mask
IndicPhotoOCR/utils/strhub/models/abinet/model_abinet_iter.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from .model_alignment import BaseAlignment
5
+ from .model_language import BCNLanguage
6
+ from .model_vision import BaseVision
7
+
8
+
9
+ class ABINetIterModel(nn.Module):
10
+ def __init__(self, dataset_max_length, null_label, num_classes, iter_size=1,
11
+ d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
12
+ v_loss_weight=1., v_attention='position', v_attention_mode='nearest',
13
+ v_backbone='transformer', v_num_layers=2,
14
+ l_loss_weight=1., l_num_layers=4, l_detach=True, l_use_self_attn=False,
15
+ a_loss_weight=1.):
16
+ super().__init__()
17
+ self.iter_size = iter_size
18
+ self.vision = BaseVision(dataset_max_length, null_label, num_classes, v_attention, v_attention_mode,
19
+ v_loss_weight, d_model, nhead, d_inner, dropout, activation, v_backbone, v_num_layers)
20
+ self.language = BCNLanguage(dataset_max_length, null_label, num_classes, d_model, nhead, d_inner, dropout,
21
+ activation, l_num_layers, l_detach, l_use_self_attn, l_loss_weight)
22
+ self.alignment = BaseAlignment(dataset_max_length, null_label, num_classes, d_model, a_loss_weight)
23
+
24
+ def forward(self, images):
25
+ v_res = self.vision(images)
26
+ a_res = v_res
27
+ all_l_res, all_a_res = [], []
28
+ for _ in range(self.iter_size):
29
+ tokens = torch.softmax(a_res['logits'], dim=-1)
30
+ lengths = a_res['pt_lengths']
31
+ lengths.clamp_(2, self.language.max_length) # TODO:move to langauge model
32
+ l_res = self.language(tokens, lengths)
33
+ all_l_res.append(l_res)
34
+ a_res = self.alignment(l_res['feature'], v_res['feature'])
35
+ all_a_res.append(a_res)
36
+ if self.training:
37
+ return all_a_res, all_l_res, v_res
38
+ else:
39
+ return a_res, all_l_res[-1], v_res
IndicPhotoOCR/utils/strhub/models/abinet/model_alignment.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .model import Model
5
+
6
+
7
+ class BaseAlignment(Model):
8
+ def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, loss_weight=1.0):
9
+ super().__init__(dataset_max_length, null_label)
10
+ self.loss_weight = loss_weight
11
+ self.w_att = nn.Linear(2 * d_model, d_model)
12
+ self.cls = nn.Linear(d_model, num_classes)
13
+
14
+ def forward(self, l_feature, v_feature):
15
+ """
16
+ Args:
17
+ l_feature: (N, T, E) where T is length, N is batch size and d is dim of model
18
+ v_feature: (N, T, E) shape the same as l_feature
19
+ """
20
+ f = torch.cat((l_feature, v_feature), dim=2)
21
+ f_att = torch.sigmoid(self.w_att(f))
22
+ output = f_att * v_feature + (1 - f_att) * l_feature
23
+
24
+ logits = self.cls(output) # (N, T, C)
25
+ pt_lengths = self._get_length(logits)
26
+
27
+ return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight': self.loss_weight,
28
+ 'name': 'alignment'}
IndicPhotoOCR/utils/strhub/models/abinet/model_language.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from .model import Model
4
+ from .transformer import PositionalEncoding, TransformerDecoderLayer, TransformerDecoder
5
+
6
+
7
+ class BCNLanguage(Model):
8
+ def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, nhead=8, d_inner=2048, dropout=0.1,
9
+ activation='relu', num_layers=4, detach=True, use_self_attn=False, loss_weight=1.0,
10
+ global_debug=False):
11
+ super().__init__(dataset_max_length, null_label)
12
+ self.detach = detach
13
+ self.loss_weight = loss_weight
14
+ self.proj = nn.Linear(num_classes, d_model, False)
15
+ self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length)
16
+ self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length)
17
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout,
18
+ activation, self_attn=use_self_attn, debug=global_debug)
19
+ self.model = TransformerDecoder(decoder_layer, num_layers)
20
+ self.cls = nn.Linear(d_model, num_classes)
21
+
22
+ def forward(self, tokens, lengths):
23
+ """
24
+ Args:
25
+ tokens: (N, T, C) where T is length, N is batch size and C is classes number
26
+ lengths: (N,)
27
+ """
28
+ if self.detach:
29
+ tokens = tokens.detach()
30
+ embed = self.proj(tokens) # (N, T, E)
31
+ embed = embed.permute(1, 0, 2) # (T, N, E)
32
+ embed = self.token_encoder(embed) # (T, N, E)
33
+ padding_mask = self._get_padding_mask(lengths, self.max_length)
34
+
35
+ zeros = embed.new_zeros(*embed.shape)
36
+ qeury = self.pos_encoder(zeros)
37
+ location_mask = self._get_location_mask(self.max_length, tokens.device)
38
+ output = self.model(qeury, embed,
39
+ tgt_key_padding_mask=padding_mask,
40
+ memory_mask=location_mask,
41
+ memory_key_padding_mask=padding_mask) # (T, N, E)
42
+ output = output.permute(1, 0, 2) # (N, T, E)
43
+
44
+ logits = self.cls(output) # (N, T, C)
45
+ pt_lengths = self._get_length(logits)
46
+
47
+ res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths,
48
+ 'loss_weight': self.loss_weight, 'name': 'language'}
49
+ return res
IndicPhotoOCR/utils/strhub/models/abinet/model_vision.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ from .attention import PositionAttention, Attention
4
+ from .backbone import ResTranformer
5
+ from .model import Model
6
+ from .resnet import resnet45
7
+
8
+
9
+ class BaseVision(Model):
10
+ def __init__(self, dataset_max_length, null_label, num_classes,
11
+ attention='position', attention_mode='nearest', loss_weight=1.0,
12
+ d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
13
+ backbone='transformer', backbone_ln=2):
14
+ super().__init__(dataset_max_length, null_label)
15
+ self.loss_weight = loss_weight
16
+ self.out_channels = d_model
17
+
18
+ if backbone == 'transformer':
19
+ self.backbone = ResTranformer(d_model, nhead, d_inner, dropout, activation, backbone_ln)
20
+ else:
21
+ self.backbone = resnet45()
22
+
23
+ if attention == 'position':
24
+ self.attention = PositionAttention(
25
+ max_length=self.max_length,
26
+ mode=attention_mode
27
+ )
28
+ elif attention == 'attention':
29
+ self.attention = Attention(
30
+ max_length=self.max_length,
31
+ n_feature=8 * 32,
32
+ )
33
+ else:
34
+ raise ValueError(f'invalid attention: {attention}')
35
+
36
+ self.cls = nn.Linear(self.out_channels, num_classes)
37
+
38
+ def forward(self, images):
39
+ features = self.backbone(images) # (N, E, H, W)
40
+ attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W)
41
+ logits = self.cls(attn_vecs) # (N, T, C)
42
+ pt_lengths = self._get_length(logits)
43
+
44
+ return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
45
+ 'attn_scores': attn_scores, 'loss_weight': self.loss_weight, 'name': 'vision'}
IndicPhotoOCR/utils/strhub/models/abinet/resnet.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Callable
3
+
4
+ import torch.nn as nn
5
+ from torchvision.models import resnet
6
+
7
+
8
+ class BasicBlock(resnet.BasicBlock):
9
+
10
+ def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None,
11
+ groups: int = 1, base_width: int = 64, dilation: int = 1,
12
+ norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
13
+ super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)
14
+ self.conv1 = resnet.conv1x1(inplanes, planes)
15
+ self.conv2 = resnet.conv3x3(planes, planes, stride)
16
+
17
+
18
+ class ResNet(nn.Module):
19
+
20
+ def __init__(self, block, layers):
21
+ super().__init__()
22
+ self.inplanes = 32
23
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1,
24
+ bias=False)
25
+ self.bn1 = nn.BatchNorm2d(32)
26
+ self.relu = nn.ReLU(inplace=True)
27
+
28
+ self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
29
+ self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
30
+ self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
31
+ self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
32
+ self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
33
+
34
+ for m in self.modules():
35
+ if isinstance(m, nn.Conv2d):
36
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
37
+ m.weight.data.normal_(0, math.sqrt(2. / n))
38
+ elif isinstance(m, nn.BatchNorm2d):
39
+ m.weight.data.fill_(1)
40
+ m.bias.data.zero_()
41
+
42
+ def _make_layer(self, block, planes, blocks, stride=1):
43
+ downsample = None
44
+ if stride != 1 or self.inplanes != planes * block.expansion:
45
+ downsample = nn.Sequential(
46
+ nn.Conv2d(self.inplanes, planes * block.expansion,
47
+ kernel_size=1, stride=stride, bias=False),
48
+ nn.BatchNorm2d(planes * block.expansion),
49
+ )
50
+
51
+ layers = []
52
+ layers.append(block(self.inplanes, planes, stride, downsample))
53
+ self.inplanes = planes * block.expansion
54
+ for i in range(1, blocks):
55
+ layers.append(block(self.inplanes, planes))
56
+
57
+ return nn.Sequential(*layers)
58
+
59
+ def forward(self, x):
60
+ x = self.conv1(x)
61
+ x = self.bn1(x)
62
+ x = self.relu(x)
63
+ x = self.layer1(x)
64
+ x = self.layer2(x)
65
+ x = self.layer3(x)
66
+ x = self.layer4(x)
67
+ x = self.layer5(x)
68
+ return x
69
+
70
+
71
+ def resnet45():
72
+ return ResNet(BasicBlock, [3, 4, 6, 6, 3])
IndicPhotoOCR/utils/strhub/models/abinet/system.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import math
18
+ from typing import Any, Optional
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from torch import Tensor, nn
23
+ from torch.optim import AdamW
24
+ from torch.optim.lr_scheduler import OneCycleLR
25
+
26
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
27
+ from timm.optim.optim_factory import param_groups_weight_decay
28
+
29
+ from strhub.models.base import CrossEntropySystem
30
+ from strhub.models.utils import init_weights
31
+
32
+ from .model_abinet_iter import ABINetIterModel as Model
33
+
34
+ log = logging.getLogger(__name__)
35
+
36
+
37
+ class ABINet(CrossEntropySystem):
38
+
39
+ def __init__(
40
+ self,
41
+ charset_train: str,
42
+ charset_test: str,
43
+ max_label_length: int,
44
+ batch_size: int,
45
+ lr: float,
46
+ warmup_pct: float,
47
+ weight_decay: float,
48
+ iter_size: int,
49
+ d_model: int,
50
+ nhead: int,
51
+ d_inner: int,
52
+ dropout: float,
53
+ activation: str,
54
+ v_loss_weight: float,
55
+ v_attention: str,
56
+ v_attention_mode: str,
57
+ v_backbone: str,
58
+ v_num_layers: int,
59
+ l_loss_weight: float,
60
+ l_num_layers: int,
61
+ l_detach: bool,
62
+ l_use_self_attn: bool,
63
+ l_lr: float,
64
+ a_loss_weight: float,
65
+ lm_only: bool = False,
66
+ **kwargs,
67
+ ) -> None:
68
+ super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
69
+ self.scheduler = None
70
+ self.save_hyperparameters()
71
+ self.max_label_length = max_label_length
72
+ self.num_classes = len(self.tokenizer) - 2 # We don't predict <bos> nor <pad>
73
+ self.model = Model(
74
+ max_label_length,
75
+ self.eos_id,
76
+ self.num_classes,
77
+ iter_size,
78
+ d_model,
79
+ nhead,
80
+ d_inner,
81
+ dropout,
82
+ activation,
83
+ v_loss_weight,
84
+ v_attention,
85
+ v_attention_mode,
86
+ v_backbone,
87
+ v_num_layers,
88
+ l_loss_weight,
89
+ l_num_layers,
90
+ l_detach,
91
+ l_use_self_attn,
92
+ a_loss_weight,
93
+ )
94
+ self.model.apply(init_weights)
95
+ # FIXME: doesn't support resumption from checkpoint yet
96
+ self._reset_alignment = True
97
+ self._reset_optimizers = True
98
+ self.l_lr = l_lr
99
+ self.lm_only = lm_only
100
+ # Train LM only. Freeze other submodels.
101
+ if lm_only:
102
+ self.l_lr = lr # for tuning
103
+ self.model.vision.requires_grad_(False)
104
+ self.model.alignment.requires_grad_(False)
105
+
106
+ @property
107
+ def _pretraining(self):
108
+ # In the original work, VM was pretrained for 8 epochs while full model was trained for an additional 10 epochs.
109
+ total_steps = self.trainer.estimated_stepping_batches * self.trainer.accumulate_grad_batches
110
+ return self.global_step < (8 / (8 + 10)) * total_steps
111
+
112
+ @torch.jit.ignore
113
+ def no_weight_decay(self):
114
+ return {'model.language.proj.weight'}
115
+
116
+ def _add_weight_decay(self, model: nn.Module, skip_list=()):
117
+ if self.weight_decay:
118
+ return param_groups_weight_decay(model, self.weight_decay, skip_list)
119
+ else:
120
+ return [{'params': model.parameters()}]
121
+
122
+ def configure_optimizers(self):
123
+ agb = self.trainer.accumulate_grad_batches
124
+ # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP.
125
+ lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.0
126
+ lr = lr_scale * self.lr
127
+ l_lr = lr_scale * self.l_lr
128
+ params = []
129
+ params.extend(self._add_weight_decay(self.model.vision))
130
+ params.extend(self._add_weight_decay(self.model.alignment))
131
+ # We use a different learning rate for the LM.
132
+ for p in self._add_weight_decay(self.model.language, ('proj.weight',)):
133
+ p['lr'] = l_lr
134
+ params.append(p)
135
+ max_lr = [p.get('lr', lr) for p in params]
136
+ optim = AdamW(params, lr)
137
+ self.scheduler = OneCycleLR(
138
+ optim, max_lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, cycle_momentum=False
139
+ )
140
+ return {'optimizer': optim, 'lr_scheduler': {'scheduler': self.scheduler, 'interval': 'step'}}
141
+
142
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
143
+ max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
144
+ logits = self.model.forward(images)[0]['logits']
145
+ return logits[:, : max_length + 1] # truncate
146
+
147
+ def calc_loss(self, targets, *res_lists) -> Tensor:
148
+ total_loss = 0
149
+ for res_list in res_lists:
150
+ loss = 0
151
+ if isinstance(res_list, dict):
152
+ res_list = [res_list]
153
+ for res in res_list:
154
+ logits = res['logits'].flatten(end_dim=1)
155
+ loss += F.cross_entropy(logits, targets.flatten(), ignore_index=self.pad_id)
156
+ loss /= len(res_list)
157
+ self.log('loss_' + res_list[0]['name'], loss)
158
+ total_loss += res_list[0]['loss_weight'] * loss
159
+ return total_loss
160
+
161
+ def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
162
+ if not self._pretraining and self._reset_optimizers:
163
+ log.info('Pretraining ends. Updating base LRs.')
164
+ self._reset_optimizers = False
165
+ # Make base_lr the same for all groups
166
+ base_lr = self.scheduler.base_lrs[0] # base_lr of group 0 - VM
167
+ self.scheduler.base_lrs = [base_lr] * len(self.scheduler.base_lrs)
168
+
169
+ def _prepare_inputs_and_targets(self, labels):
170
+ # Use dummy label to ensure sequence length is constant.
171
+ dummy = ['0' * self.max_label_length]
172
+ targets = self.tokenizer.encode(dummy + list(labels), self.device)[1:]
173
+ targets = targets[:, 1:] # remove <bos>. Unused here.
174
+ # Inputs are padded with eos_id
175
+ inputs = torch.where(targets == self.pad_id, self.eos_id, targets)
176
+ inputs = F.one_hot(inputs, self.num_classes).float()
177
+ lengths = torch.as_tensor(list(map(len, labels)), device=self.device) + 1 # +1 for eos
178
+ return inputs, lengths, targets
179
+
180
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
181
+ images, labels = batch
182
+ inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
183
+ if self.lm_only:
184
+ l_res = self.model.language(inputs, lengths)
185
+ loss = self.calc_loss(targets, l_res)
186
+ # Pretrain submodels independently first
187
+ elif self._pretraining:
188
+ # Vision
189
+ v_res = self.model.vision(images)
190
+ # Language
191
+ l_res = self.model.language(inputs, lengths)
192
+ # We also train the alignment model to 'satisfy' DDP requirements (all parameters should be used).
193
+ # We'll reset its parameters prior to joint training.
194
+ a_res = self.model.alignment(l_res['feature'].detach(), v_res['feature'].detach())
195
+ loss = self.calc_loss(targets, v_res, l_res, a_res)
196
+ else:
197
+ # Reset alignment model's parameters once prior to full model training.
198
+ if self._reset_alignment:
199
+ log.info('Pretraining ends. Resetting alignment model.')
200
+ self._reset_alignment = False
201
+ self.model.alignment.apply(init_weights)
202
+ all_a_res, all_l_res, v_res = self.model.forward(images)
203
+ loss = self.calc_loss(targets, v_res, all_l_res, all_a_res)
204
+ self.log('loss', loss)
205
+ return loss
206
+
207
+ def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
208
+ if self.lm_only:
209
+ inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
210
+ l_res = self.model.language(inputs, lengths)
211
+ loss = self.calc_loss(targets, l_res)
212
+ loss_numel = (targets != self.pad_id).sum()
213
+ return l_res['logits'], loss, loss_numel
214
+ else:
215
+ return super().forward_logits_loss(images, labels)
IndicPhotoOCR/utils/strhub/models/abinet/transformer.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from torch.nn.modules.transformer import _get_activation_fn, _get_clones
7
+
8
+
9
+ class TransformerDecoder(nn.Module):
10
+ r"""TransformerDecoder is a stack of N decoder layers
11
+
12
+ Args:
13
+ decoder_layer: an instance of the TransformerDecoderLayer() class (required).
14
+ num_layers: the number of sub-decoder-layers in the decoder (required).
15
+ norm: the layer normalization component (optional).
16
+
17
+ Examples::
18
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
19
+ >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
20
+ >>> memory = torch.rand(10, 32, 512)
21
+ >>> tgt = torch.rand(20, 32, 512)
22
+ >>> out = transformer_decoder(tgt, memory)
23
+ """
24
+ __constants__ = ['norm']
25
+
26
+ def __init__(self, decoder_layer, num_layers, norm=None):
27
+ super(TransformerDecoder, self).__init__()
28
+ self.layers = _get_clones(decoder_layer, num_layers)
29
+ self.num_layers = num_layers
30
+ self.norm = norm
31
+
32
+ def forward(self, tgt, memory, memory2=None, tgt_mask=None,
33
+ memory_mask=None, memory_mask2=None, tgt_key_padding_mask=None,
34
+ memory_key_padding_mask=None, memory_key_padding_mask2=None):
35
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
36
+ r"""Pass the inputs (and mask) through the decoder layer in turn.
37
+
38
+ Args:
39
+ tgt: the sequence to the decoder (required).
40
+ memory: the sequence from the last layer of the encoder (required).
41
+ tgt_mask: the mask for the tgt sequence (optional).
42
+ memory_mask: the mask for the memory sequence (optional).
43
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
44
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
45
+
46
+ Shape:
47
+ see the docs in Transformer class.
48
+ """
49
+ output = tgt
50
+
51
+ for mod in self.layers:
52
+ output = mod(output, memory, memory2=memory2, tgt_mask=tgt_mask,
53
+ memory_mask=memory_mask, memory_mask2=memory_mask2,
54
+ tgt_key_padding_mask=tgt_key_padding_mask,
55
+ memory_key_padding_mask=memory_key_padding_mask,
56
+ memory_key_padding_mask2=memory_key_padding_mask2)
57
+
58
+ if self.norm is not None:
59
+ output = self.norm(output)
60
+
61
+ return output
62
+
63
+
64
+ class TransformerDecoderLayer(nn.Module):
65
+ r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
66
+ This standard decoder layer is based on the paper "Attention Is All You Need".
67
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
68
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
69
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
70
+ in a different way during application.
71
+
72
+ Args:
73
+ d_model: the number of expected features in the input (required).
74
+ nhead: the number of heads in the multiheadattention models (required).
75
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
76
+ dropout: the dropout value (default=0.1).
77
+ activation: the activation function of intermediate layer, relu or gelu (default=relu).
78
+
79
+ Examples::
80
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
81
+ >>> memory = torch.rand(10, 32, 512)
82
+ >>> tgt = torch.rand(20, 32, 512)
83
+ >>> out = decoder_layer(tgt, memory)
84
+ """
85
+
86
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
87
+ activation="relu", self_attn=True, siamese=False, debug=False):
88
+ super().__init__()
89
+ self.has_self_attn, self.siamese = self_attn, siamese
90
+ self.debug = debug
91
+ if self.has_self_attn:
92
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
93
+ self.norm1 = nn.LayerNorm(d_model)
94
+ self.dropout1 = nn.Dropout(dropout)
95
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
96
+ # Implementation of Feedforward model
97
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
98
+ self.dropout = nn.Dropout(dropout)
99
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
100
+
101
+ self.norm2 = nn.LayerNorm(d_model)
102
+ self.norm3 = nn.LayerNorm(d_model)
103
+ self.dropout2 = nn.Dropout(dropout)
104
+ self.dropout3 = nn.Dropout(dropout)
105
+ if self.siamese:
106
+ self.multihead_attn2 = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
107
+
108
+ self.activation = _get_activation_fn(activation)
109
+
110
+ def __setstate__(self, state):
111
+ if 'activation' not in state:
112
+ state['activation'] = F.relu
113
+ super().__setstate__(state)
114
+
115
+ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
116
+ tgt_key_padding_mask=None, memory_key_padding_mask=None,
117
+ memory2=None, memory_mask2=None, memory_key_padding_mask2=None):
118
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
119
+ r"""Pass the inputs (and mask) through the decoder layer.
120
+
121
+ Args:
122
+ tgt: the sequence to the decoder layer (required).
123
+ memory: the sequence from the last layer of the encoder (required).
124
+ tgt_mask: the mask for the tgt sequence (optional).
125
+ memory_mask: the mask for the memory sequence (optional).
126
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
127
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
128
+
129
+ Shape:
130
+ see the docs in Transformer class.
131
+ """
132
+ if self.has_self_attn:
133
+ tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
134
+ key_padding_mask=tgt_key_padding_mask)
135
+ tgt = tgt + self.dropout1(tgt2)
136
+ tgt = self.norm1(tgt)
137
+ if self.debug: self.attn = attn
138
+ tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
139
+ key_padding_mask=memory_key_padding_mask)
140
+ if self.debug: self.attn2 = attn2
141
+
142
+ if self.siamese:
143
+ tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2,
144
+ key_padding_mask=memory_key_padding_mask2)
145
+ tgt = tgt + self.dropout2(tgt3)
146
+ if self.debug: self.attn3 = attn3
147
+
148
+ tgt = tgt + self.dropout2(tgt2)
149
+ tgt = self.norm2(tgt)
150
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
151
+ tgt = tgt + self.dropout3(tgt2)
152
+ tgt = self.norm3(tgt)
153
+
154
+ return tgt
155
+
156
+
157
+ class PositionalEncoding(nn.Module):
158
+ r"""Inject some information about the relative or absolute position of the tokens
159
+ in the sequence. The positional encodings have the same dimension as
160
+ the embeddings, so that the two can be summed. Here, we use sine and cosine
161
+ functions of different frequencies.
162
+ .. math::
163
+ \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
164
+ \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
165
+ \text{where pos is the word position and i is the embed idx)
166
+ Args:
167
+ d_model: the embed dim (required).
168
+ dropout: the dropout value (default=0.1).
169
+ max_len: the max. length of the incoming sequence (default=5000).
170
+ Examples:
171
+ >>> pos_encoder = PositionalEncoding(d_model)
172
+ """
173
+
174
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
175
+ super().__init__()
176
+ self.dropout = nn.Dropout(p=dropout)
177
+
178
+ pe = torch.zeros(max_len, d_model)
179
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
180
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
181
+ pe[:, 0::2] = torch.sin(position * div_term)
182
+ pe[:, 1::2] = torch.cos(position * div_term)
183
+ pe = pe.unsqueeze(0).transpose(0, 1)
184
+ self.register_buffer('pe', pe)
185
+
186
+ def forward(self, x):
187
+ r"""Inputs of forward function
188
+ Args:
189
+ x: the sequence fed to the positional encoder model (required).
190
+ Shape:
191
+ x: [sequence length, batch size, embed dim]
192
+ output: [sequence length, batch size, embed dim]
193
+ Examples:
194
+ >>> output = pos_encoder(x)
195
+ """
196
+
197
+ x = x + self.pe[:x.size(0), :]
198
+ return self.dropout(x)
IndicPhotoOCR/utils/strhub/models/base.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from abc import ABC, abstractmethod
18
+ from dataclasses import dataclass
19
+ from typing import Optional
20
+
21
+ from nltk import edit_distance
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from torch import Tensor
26
+ from torch.optim import Optimizer
27
+ from torch.optim.lr_scheduler import OneCycleLR
28
+
29
+ import pytorch_lightning as pl
30
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
31
+ from timm.optim import create_optimizer_v2
32
+
33
+ from IndicPhotoOCR.utils.strhub.data.utils import BaseTokenizer, CharsetAdapter, CTCTokenizer, Tokenizer
34
+
35
+
36
+ @dataclass
37
+ class BatchResult:
38
+ num_samples: int
39
+ correct: int
40
+ ned: float
41
+ confidence: float
42
+ label_length: int
43
+ loss: Tensor
44
+ loss_numel: int
45
+
46
+
47
+ EPOCH_OUTPUT = list[dict[str, BatchResult]]
48
+
49
+
50
+ class BaseSystem(pl.LightningModule, ABC):
51
+
52
+ def __init__(
53
+ self,
54
+ tokenizer: BaseTokenizer,
55
+ charset_test: str,
56
+ batch_size: int,
57
+ lr: float,
58
+ warmup_pct: float,
59
+ weight_decay: float,
60
+ ) -> None:
61
+ super().__init__()
62
+ self.tokenizer = tokenizer
63
+ self.charset_adapter = CharsetAdapter(charset_test)
64
+ self.batch_size = batch_size
65
+ self.lr = lr
66
+ self.warmup_pct = warmup_pct
67
+ self.weight_decay = weight_decay
68
+ self.outputs: EPOCH_OUTPUT = []
69
+
70
+ @abstractmethod
71
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
72
+ """Inference
73
+
74
+ Args:
75
+ images: Batch of images. Shape: N, Ch, H, W
76
+ max_length: Max sequence length of the output. If None, will use default.
77
+
78
+ Returns:
79
+ logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
80
+ """
81
+ raise NotImplementedError
82
+
83
+ @abstractmethod
84
+ def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
85
+ """Like forward(), but also computes the loss (calls forward() internally).
86
+
87
+ Args:
88
+ images: Batch of images. Shape: N, Ch, H, W
89
+ labels: Text labels of the images
90
+
91
+ Returns:
92
+ logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
93
+ loss: mean loss for the batch
94
+ loss_numel: number of elements the loss was calculated from
95
+ """
96
+ raise NotImplementedError
97
+
98
+ def configure_optimizers(self):
99
+ agb = self.trainer.accumulate_grad_batches
100
+ # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP.
101
+ lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.0
102
+ lr = lr_scale * self.lr
103
+ optim = create_optimizer_v2(self, 'adamw', lr, self.weight_decay)
104
+ sched = OneCycleLR(
105
+ optim, lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, cycle_momentum=False
106
+ )
107
+ return {'optimizer': optim, 'lr_scheduler': {'scheduler': sched, 'interval': 'step'}}
108
+
109
+ def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer) -> None:
110
+ optimizer.zero_grad(set_to_none=True)
111
+
112
+ def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]:
113
+ images, labels = batch
114
+
115
+ correct = 0
116
+ total = 0
117
+ ned = 0
118
+ confidence = 0
119
+ label_length = 0
120
+ if validation:
121
+ logits, loss, loss_numel = self.forward_logits_loss(images, labels)
122
+ else:
123
+ # At test-time, we shouldn't specify a max_label_length because the test-time charset used
124
+ # might be different from the train-time charset. max_label_length in eval_logits_loss() is computed
125
+ # based on the transformed label, which could be wrong if the actual gt label contains characters existing
126
+ # in the train-time charset but not in the test-time charset. For example, "aishahaleyes.blogspot.com"
127
+ # is exactly 25 characters, but if processed by CharsetAdapter for the 36-char set, it becomes 23 characters
128
+ # long only, which sets max_label_length = 23. This will cause the model prediction to be truncated.
129
+ logits = self.forward(images)
130
+ loss = loss_numel = None # Only used for validation; not needed at test-time.
131
+
132
+ probs = logits.softmax(-1)
133
+ preds, probs = self.tokenizer.decode(probs)
134
+ for pred, prob, gt in zip(preds, probs, labels):
135
+ confidence += prob.prod().item()
136
+ pred = self.charset_adapter(pred)
137
+ # Follow ICDAR 2019 definition of N.E.D.
138
+ ned += edit_distance(pred, gt) / max(len(pred), len(gt))
139
+ if pred == gt:
140
+ correct += 1
141
+ total += 1
142
+ label_length += len(pred)
143
+ return dict(output=BatchResult(total, correct, ned, confidence, label_length, loss, loss_numel))
144
+
145
+ @staticmethod
146
+ def _aggregate_results(outputs: EPOCH_OUTPUT) -> tuple[float, float, float]:
147
+ if not outputs:
148
+ return 0.0, 0.0, 0.0
149
+ total_loss = 0
150
+ total_loss_numel = 0
151
+ total_n_correct = 0
152
+ total_norm_ED = 0
153
+ total_size = 0
154
+ for result in outputs:
155
+ result = result['output']
156
+ total_loss += result.loss_numel * result.loss
157
+ total_loss_numel += result.loss_numel
158
+ total_n_correct += result.correct
159
+ total_norm_ED += result.ned
160
+ total_size += result.num_samples
161
+ acc = total_n_correct / total_size
162
+ ned = 1 - total_norm_ED / total_size
163
+ loss = total_loss / total_loss_numel
164
+ return acc, ned, loss
165
+
166
+ def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
167
+ result = self._eval_step(batch, True)
168
+ self.outputs.append(result)
169
+ return result
170
+
171
+ def on_validation_epoch_end(self) -> None:
172
+ acc, ned, loss = self._aggregate_results(self.outputs)
173
+ self.outputs.clear()
174
+ self.log('val_accuracy', 100 * acc, sync_dist=True)
175
+ self.log('val_NED', 100 * ned, sync_dist=True)
176
+ self.log('val_loss', loss, sync_dist=True)
177
+ self.log('hp_metric', acc, sync_dist=True)
178
+
179
+ def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
180
+ return self._eval_step(batch, False)
181
+
182
+
183
+ class CrossEntropySystem(BaseSystem):
184
+
185
+ def __init__(
186
+ self, charset_train: str, charset_test: str, batch_size: int, lr: float, warmup_pct: float, weight_decay: float
187
+ ) -> None:
188
+ tokenizer = Tokenizer(charset_train)
189
+ super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
190
+ self.bos_id = tokenizer.bos_id
191
+ self.eos_id = tokenizer.eos_id
192
+ self.pad_id = tokenizer.pad_id
193
+
194
+ def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
195
+ targets = self.tokenizer.encode(labels, self.device)
196
+ targets = targets[:, 1:] # Discard <bos>
197
+ max_len = targets.shape[1] - 1 # exclude <eos> from count
198
+ logits = self.forward(images, max_len)
199
+ loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id)
200
+ loss_numel = (targets != self.pad_id).sum()
201
+ return logits, loss, loss_numel
202
+
203
+
204
+ class CTCSystem(BaseSystem):
205
+
206
+ def __init__(
207
+ self, charset_train: str, charset_test: str, batch_size: int, lr: float, warmup_pct: float, weight_decay: float
208
+ ) -> None:
209
+ tokenizer = CTCTokenizer(charset_train)
210
+ super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
211
+ self.blank_id = tokenizer.blank_id
212
+
213
+ def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
214
+ targets = self.tokenizer.encode(labels, self.device)
215
+ logits = self.forward(images)
216
+ log_probs = logits.log_softmax(-1).transpose(0, 1) # swap batch and seq. dims
217
+ T, N, _ = log_probs.shape
218
+ input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long, device=self.device)
219
+ target_lengths = torch.as_tensor(list(map(len, labels)), dtype=torch.long, device=self.device)
220
+ loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True)
221
+ return logits, loss, N
IndicPhotoOCR/utils/strhub/models/crnn/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) 2017 Jieru Mei <meijieru@gmail.com>
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
IndicPhotoOCR/utils/strhub/models/crnn/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""
2
+ Shi, Baoguang, Xiang Bai, and Cong Yao.
3
+ "An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition."
4
+ IEEE transactions on pattern analysis and machine intelligence 39, no. 11 (2016): 2298-2304.
5
+
6
+ https://arxiv.org/abs/1507.05717
7
+
8
+ All source files, except `system.py`, are based on the implementation listed below,
9
+ and hence are released under the license of the original.
10
+
11
+ Source: https://github.com/meijieru/crnn.pytorch
12
+ License: MIT License (see included LICENSE file)
13
+ """
IndicPhotoOCR/utils/strhub/models/crnn/model.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from strhub.models.modules import BidirectionalLSTM
4
+
5
+
6
+ class CRNN(nn.Module):
7
+
8
+ def __init__(self, img_h, nc, nclass, nh, leaky_relu=False):
9
+ super().__init__()
10
+ assert img_h % 16 == 0, 'img_h has to be a multiple of 16'
11
+
12
+ ks = [3, 3, 3, 3, 3, 3, 2]
13
+ ps = [1, 1, 1, 1, 1, 1, 0]
14
+ ss = [1, 1, 1, 1, 1, 1, 1]
15
+ nm = [64, 128, 256, 256, 512, 512, 512]
16
+
17
+ cnn = nn.Sequential()
18
+
19
+ def convRelu(i, batchNormalization=False):
20
+ nIn = nc if i == 0 else nm[i - 1]
21
+ nOut = nm[i]
22
+ cnn.add_module(f'conv{i}',
23
+ nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i], bias=not batchNormalization))
24
+ if batchNormalization:
25
+ cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(nOut))
26
+ if leaky_relu:
27
+ cnn.add_module(f'relu{i}',
28
+ nn.LeakyReLU(0.2, inplace=True))
29
+ else:
30
+ cnn.add_module(f'relu{i}', nn.ReLU(True))
31
+
32
+ convRelu(0)
33
+ cnn.add_module('pooling0', nn.MaxPool2d(2, 2)) # 64x16x64
34
+ convRelu(1)
35
+ cnn.add_module('pooling1', nn.MaxPool2d(2, 2)) # 128x8x32
36
+ convRelu(2, True)
37
+ convRelu(3)
38
+ cnn.add_module('pooling2',
39
+ nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
40
+ convRelu(4, True)
41
+ convRelu(5)
42
+ cnn.add_module('pooling3',
43
+ nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
44
+ convRelu(6, True) # 512x1x16
45
+
46
+ self.cnn = cnn
47
+ self.rnn = nn.Sequential(
48
+ BidirectionalLSTM(512, nh, nh),
49
+ BidirectionalLSTM(nh, nh, nclass))
50
+
51
+ def forward(self, input):
52
+ # conv features
53
+ conv = self.cnn(input)
54
+ b, c, h, w = conv.size()
55
+ assert h == 1, 'the height of conv must be 1'
56
+ conv = conv.squeeze(2)
57
+ conv = conv.transpose(1, 2) # [b, w, c]
58
+
59
+ # rnn features
60
+ output = self.rnn(conv)
61
+
62
+ return output
IndicPhotoOCR/utils/strhub/models/crnn/system.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Sequence
17
+
18
+ from torch import Tensor
19
+
20
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
21
+
22
+ from strhub.models.base import CTCSystem
23
+ from strhub.models.utils import init_weights
24
+
25
+ from .model import CRNN as Model
26
+
27
+
28
+ class CRNN(CTCSystem):
29
+
30
+ def __init__(
31
+ self,
32
+ charset_train: str,
33
+ charset_test: str,
34
+ max_label_length: int,
35
+ batch_size: int,
36
+ lr: float,
37
+ warmup_pct: float,
38
+ weight_decay: float,
39
+ img_size: Sequence[int],
40
+ hidden_size: int,
41
+ leaky_relu: bool,
42
+ **kwargs,
43
+ ) -> None:
44
+ super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
45
+ self.save_hyperparameters()
46
+ self.model = Model(img_size[0], 3, len(self.tokenizer), hidden_size, leaky_relu)
47
+ self.model.apply(init_weights)
48
+
49
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
50
+ return self.model.forward(images)
51
+
52
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
53
+ images, labels = batch
54
+ loss = self.forward_logits_loss(images, labels)[1]
55
+ self.log('loss', loss)
56
+ return loss
IndicPhotoOCR/utils/strhub/models/modules.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""Shared modules used by CRNN and TRBA"""
2
+ from torch import nn
3
+
4
+
5
+ class BidirectionalLSTM(nn.Module):
6
+ """Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py"""
7
+
8
+ def __init__(self, input_size, hidden_size, output_size):
9
+ super().__init__()
10
+ self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
11
+ self.linear = nn.Linear(hidden_size * 2, output_size)
12
+
13
+ def forward(self, input):
14
+ """
15
+ input : visual feature [batch_size x T x input_size], T = num_steps.
16
+ output : contextual feature [batch_size x T x output_size]
17
+ """
18
+ recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
19
+ output = self.linear(recurrent) # batch_size x T x output_size
20
+ return output
IndicPhotoOCR/utils/strhub/models/parseq/__init__.py ADDED
File without changes
IndicPhotoOCR/utils/strhub/models/parseq/model.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ from typing import Optional, Sequence
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch import Tensor
22
+
23
+ from timm.models.helpers import named_apply
24
+
25
+ from IndicPhotoOCR.utils.strhub.data.utils import Tokenizer
26
+ from IndicPhotoOCR.utils.strhub.models.utils import init_weights
27
+
28
+ from .modules import Decoder, DecoderLayer, Encoder, TokenEmbedding
29
+
30
+
31
+ class PARSeq(nn.Module):
32
+
33
+ def __init__(
34
+ self,
35
+ num_tokens: int,
36
+ max_label_length: int,
37
+ img_size: Sequence[int],
38
+ patch_size: Sequence[int],
39
+ embed_dim: int,
40
+ enc_num_heads: int,
41
+ enc_mlp_ratio: int,
42
+ enc_depth: int,
43
+ dec_num_heads: int,
44
+ dec_mlp_ratio: int,
45
+ dec_depth: int,
46
+ decode_ar: bool,
47
+ refine_iters: int,
48
+ dropout: float,
49
+ ) -> None:
50
+ super().__init__()
51
+
52
+ self.max_label_length = max_label_length
53
+ self.decode_ar = decode_ar
54
+ self.refine_iters = refine_iters
55
+
56
+ self.encoder = Encoder(
57
+ img_size, patch_size, embed_dim=embed_dim, depth=enc_depth, num_heads=enc_num_heads, mlp_ratio=enc_mlp_ratio
58
+ )
59
+ decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
60
+ self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=nn.LayerNorm(embed_dim))
61
+
62
+ # We don't predict <bos> nor <pad>
63
+ self.head = nn.Linear(embed_dim, num_tokens - 2)
64
+ self.text_embed = TokenEmbedding(num_tokens, embed_dim)
65
+
66
+ # +1 for <eos>
67
+ self.pos_queries = nn.Parameter(torch.Tensor(1, max_label_length + 1, embed_dim))
68
+ self.dropout = nn.Dropout(p=dropout)
69
+ # Encoder has its own init.
70
+ named_apply(partial(init_weights, exclude=['encoder']), self)
71
+ nn.init.trunc_normal_(self.pos_queries, std=0.02)
72
+
73
+ @property
74
+ def _device(self) -> torch.device:
75
+ return next(self.head.parameters(recurse=False)).device
76
+
77
+ @torch.jit.ignore
78
+ def no_weight_decay(self):
79
+ param_names = {'text_embed.embedding.weight', 'pos_queries'}
80
+ enc_param_names = {'encoder.' + n for n in self.encoder.no_weight_decay()}
81
+ return param_names.union(enc_param_names)
82
+
83
+ def encode(self, img: torch.Tensor):
84
+ return self.encoder(img)
85
+
86
+ def decode(
87
+ self,
88
+ tgt: torch.Tensor,
89
+ memory: torch.Tensor,
90
+ tgt_mask: Optional[Tensor] = None,
91
+ tgt_padding_mask: Optional[Tensor] = None,
92
+ tgt_query: Optional[Tensor] = None,
93
+ tgt_query_mask: Optional[Tensor] = None,
94
+ ):
95
+ N, L = tgt.shape
96
+ # <bos> stands for the null context. We only supply position information for characters after <bos>.
97
+ null_ctx = self.text_embed(tgt[:, :1])
98
+ tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
99
+ tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
100
+ if tgt_query is None:
101
+ tgt_query = self.pos_queries[:, :L].expand(N, -1, -1)
102
+ tgt_query = self.dropout(tgt_query)
103
+ return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)
104
+
105
+ def forward(self, tokenizer: Tokenizer, images: Tensor, max_length: Optional[int] = None) -> Tensor:
106
+ testing = max_length is None
107
+ max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
108
+ bs = images.shape[0]
109
+ # +1 for <eos> at end of sequence.
110
+ num_steps = max_length + 1
111
+ memory = self.encode(images)
112
+
113
+ # Query positions up to `num_steps`
114
+ pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)
115
+
116
+ # Special case for the forward permutation. Faster than using `generate_attn_masks()`
117
+ tgt_mask = query_mask = torch.triu(torch.ones((num_steps, num_steps), dtype=torch.bool, device=self._device), 1)
118
+
119
+ if self.decode_ar:
120
+ tgt_in = torch.full((bs, num_steps), tokenizer.pad_id, dtype=torch.long, device=self._device)
121
+ tgt_in[:, 0] = tokenizer.bos_id
122
+
123
+ logits = []
124
+ for i in range(num_steps):
125
+ j = i + 1 # next token index
126
+ # Efficient decoding:
127
+ # Input the context up to the ith token. We use only one query (at position = i) at a time.
128
+ # This works because of the lookahead masking effect of the canonical (forward) AR context.
129
+ # Past tokens have no access to future tokens, hence are fixed once computed.
130
+ tgt_out = self.decode(
131
+ tgt_in[:, :j],
132
+ memory,
133
+ tgt_mask[:j, :j],
134
+ tgt_query=pos_queries[:, i:j],
135
+ tgt_query_mask=query_mask[i:j, :j],
136
+ )
137
+ # the next token probability is in the output's ith token position
138
+ p_i = self.head(tgt_out)
139
+ logits.append(p_i)
140
+ if j < num_steps:
141
+ # greedy decode. add the next token index to the target input
142
+ tgt_in[:, j] = p_i.squeeze().argmax(-1)
143
+ # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
144
+ if testing and (tgt_in == tokenizer.eos_id).any(dim=-1).all():
145
+ break
146
+
147
+ logits = torch.cat(logits, dim=1)
148
+ else:
149
+ # No prior context, so input is just <bos>. We query all positions.
150
+ tgt_in = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
151
+ tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
152
+ logits = self.head(tgt_out)
153
+
154
+ if self.refine_iters:
155
+ # For iterative refinement, we always use a 'cloze' mask.
156
+ # We can derive it from the AR forward mask by unmasking the token context to the right.
157
+ query_mask[torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool, device=self._device), 2)] = 0
158
+ bos = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
159
+ for i in range(self.refine_iters):
160
+ # Prior context is the previous output.
161
+ tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
162
+ # Mask tokens beyond the first EOS token.
163
+ tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(-1) > 0
164
+ tgt_out = self.decode(
165
+ tgt_in, memory, tgt_mask, tgt_padding_mask, pos_queries, query_mask[:, : tgt_in.shape[1]]
166
+ )
167
+ logits = self.head(tgt_out)
168
+
169
+ return logits
IndicPhotoOCR/utils/strhub/models/parseq/modules.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import Optional
18
+
19
+ import torch
20
+ from torch import Tensor, nn as nn
21
+ from torch.nn import functional as F
22
+ from torch.nn.modules import transformer
23
+
24
+ from timm.models.vision_transformer import PatchEmbed, VisionTransformer
25
+
26
+
27
+ class DecoderLayer(nn.Module):
28
+ """A Transformer decoder layer supporting two-stream attention (XLNet)
29
+ This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
30
+
31
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', layer_norm_eps=1e-5):
32
+ super().__init__()
33
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
34
+ self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
35
+ # Implementation of Feedforward model
36
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
37
+ self.dropout = nn.Dropout(dropout)
38
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
39
+
40
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
41
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
42
+ self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps)
43
+ self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps)
44
+ self.dropout1 = nn.Dropout(dropout)
45
+ self.dropout2 = nn.Dropout(dropout)
46
+ self.dropout3 = nn.Dropout(dropout)
47
+
48
+ self.activation = transformer._get_activation_fn(activation)
49
+
50
+ def __setstate__(self, state):
51
+ if 'activation' not in state:
52
+ state['activation'] = F.gelu
53
+ super().__setstate__(state)
54
+
55
+ def forward_stream(
56
+ self,
57
+ tgt: Tensor,
58
+ tgt_norm: Tensor,
59
+ tgt_kv: Tensor,
60
+ memory: Tensor,
61
+ tgt_mask: Optional[Tensor],
62
+ tgt_key_padding_mask: Optional[Tensor],
63
+ ):
64
+ """Forward pass for a single stream (i.e. content or query)
65
+ tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
66
+ Both tgt_kv and memory are expected to be LayerNorm'd too.
67
+ memory is LayerNorm'd by ViT.
68
+ """
69
+ tgt2, sa_weights = self.self_attn(
70
+ tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
71
+ )
72
+ tgt = tgt + self.dropout1(tgt2)
73
+
74
+ tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
75
+ tgt = tgt + self.dropout2(tgt2)
76
+
77
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
78
+ tgt = tgt + self.dropout3(tgt2)
79
+ return tgt, sa_weights, ca_weights
80
+
81
+ def forward(
82
+ self,
83
+ query,
84
+ content,
85
+ memory,
86
+ query_mask: Optional[Tensor] = None,
87
+ content_mask: Optional[Tensor] = None,
88
+ content_key_padding_mask: Optional[Tensor] = None,
89
+ update_content: bool = True,
90
+ ):
91
+ query_norm = self.norm_q(query)
92
+ content_norm = self.norm_c(content)
93
+ query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0]
94
+ if update_content:
95
+ content = self.forward_stream(
96
+ content, content_norm, content_norm, memory, content_mask, content_key_padding_mask
97
+ )[0]
98
+ return query, content
99
+
100
+
101
+ class Decoder(nn.Module):
102
+ __constants__ = ['norm']
103
+
104
+ def __init__(self, decoder_layer, num_layers, norm):
105
+ super().__init__()
106
+ self.layers = transformer._get_clones(decoder_layer, num_layers)
107
+ self.num_layers = num_layers
108
+ self.norm = norm
109
+
110
+ def forward(
111
+ self,
112
+ query,
113
+ content,
114
+ memory,
115
+ query_mask: Optional[Tensor] = None,
116
+ content_mask: Optional[Tensor] = None,
117
+ content_key_padding_mask: Optional[Tensor] = None,
118
+ ):
119
+ for i, mod in enumerate(self.layers):
120
+ last = i == len(self.layers) - 1
121
+ query, content = mod(
122
+ query, content, memory, query_mask, content_mask, content_key_padding_mask, update_content=not last
123
+ )
124
+ query = self.norm(query)
125
+ return query
126
+
127
+
128
+ class Encoder(VisionTransformer):
129
+
130
+ def __init__(
131
+ self,
132
+ img_size=224,
133
+ patch_size=16,
134
+ in_chans=3,
135
+ embed_dim=768,
136
+ depth=12,
137
+ num_heads=12,
138
+ mlp_ratio=4.0,
139
+ qkv_bias=True,
140
+ drop_rate=0.0,
141
+ attn_drop_rate=0.0,
142
+ drop_path_rate=0.0,
143
+ embed_layer=PatchEmbed,
144
+ ):
145
+ super().__init__(
146
+ img_size,
147
+ patch_size,
148
+ in_chans,
149
+ embed_dim=embed_dim,
150
+ depth=depth,
151
+ num_heads=num_heads,
152
+ mlp_ratio=mlp_ratio,
153
+ qkv_bias=qkv_bias,
154
+ drop_rate=drop_rate,
155
+ attn_drop_rate=attn_drop_rate,
156
+ drop_path_rate=drop_path_rate,
157
+ embed_layer=embed_layer,
158
+ num_classes=0, # These
159
+ global_pool='', # disable the
160
+ class_token=False, # classifier head.
161
+ )
162
+
163
+ def forward(self, x):
164
+ # Return all tokens
165
+ return self.forward_features(x)
166
+
167
+
168
+ class TokenEmbedding(nn.Module):
169
+
170
+ def __init__(self, charset_size: int, embed_dim: int):
171
+ super().__init__()
172
+ self.embedding = nn.Embedding(charset_size, embed_dim)
173
+ self.embed_dim = embed_dim
174
+
175
+ def forward(self, tokens: torch.Tensor):
176
+ return math.sqrt(self.embed_dim) * self.embedding(tokens)
IndicPhotoOCR/utils/strhub/models/parseq/system.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from itertools import permutations
18
+ from typing import Any, Optional, Sequence
19
+
20
+ import numpy as np
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from torch import Tensor
25
+
26
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
27
+
28
+ from IndicPhotoOCR.utils.strhub.models.base import CrossEntropySystem
29
+
30
+ from .model import PARSeq as Model
31
+
32
+
33
+ class PARSeq(CrossEntropySystem):
34
+
35
+ def __init__(
36
+ self,
37
+ charset_train: str,
38
+ charset_test: str,
39
+ max_label_length: int,
40
+ batch_size: int,
41
+ lr: float,
42
+ warmup_pct: float,
43
+ weight_decay: float,
44
+ img_size: Sequence[int],
45
+ patch_size: Sequence[int],
46
+ embed_dim: int,
47
+ enc_num_heads: int,
48
+ enc_mlp_ratio: int,
49
+ enc_depth: int,
50
+ dec_num_heads: int,
51
+ dec_mlp_ratio: int,
52
+ dec_depth: int,
53
+ perm_num: int,
54
+ perm_forward: bool,
55
+ perm_mirrored: bool,
56
+ decode_ar: bool,
57
+ refine_iters: int,
58
+ dropout: float,
59
+ **kwargs: Any,
60
+ ) -> None:
61
+ super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
62
+ self.save_hyperparameters()
63
+
64
+ self.model = Model(
65
+ len(self.tokenizer),
66
+ max_label_length,
67
+ img_size,
68
+ patch_size,
69
+ embed_dim,
70
+ enc_num_heads,
71
+ enc_mlp_ratio,
72
+ enc_depth,
73
+ dec_num_heads,
74
+ dec_mlp_ratio,
75
+ dec_depth,
76
+ decode_ar,
77
+ refine_iters,
78
+ dropout,
79
+ )
80
+
81
+ # Perm/attn mask stuff
82
+ self.rng = np.random.default_rng()
83
+ self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
84
+ self.perm_forward = perm_forward
85
+ self.perm_mirrored = perm_mirrored
86
+
87
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
88
+ return self.model.forward(self.tokenizer, images, max_length)
89
+
90
+ def gen_tgt_perms(self, tgt):
91
+ """Generate shared permutations for the whole batch.
92
+ This works because the same attention mask can be used for the shorter sequences
93
+ because of the padding mask.
94
+ """
95
+ # We don't permute the position of BOS, we permute EOS separately
96
+ max_num_chars = tgt.shape[1] - 2
97
+ # Special handling for 1-character sequences
98
+ if max_num_chars == 1:
99
+ return torch.arange(3, device=self._device).unsqueeze(0)
100
+ perms = [torch.arange(max_num_chars, device=self._device)] if self.perm_forward else []
101
+ # Additional permutations if needed
102
+ max_perms = math.factorial(max_num_chars)
103
+ if self.perm_mirrored:
104
+ max_perms //= 2
105
+ num_gen_perms = min(self.max_gen_perms, max_perms)
106
+ # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions
107
+ # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars.
108
+ if max_num_chars < 5:
109
+ # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
110
+ # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
111
+ if max_num_chars == 4 and self.perm_mirrored:
112
+ selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
113
+ else:
114
+ selector = list(range(max_perms))
115
+ perm_pool = torch.as_tensor(
116
+ list(permutations(range(max_num_chars), max_num_chars)),
117
+ device=self._device,
118
+ )[selector]
119
+ # If the forward permutation is always selected, no need to add it to the pool for sampling
120
+ if self.perm_forward:
121
+ perm_pool = perm_pool[1:]
122
+ perms = torch.stack(perms)
123
+ if len(perm_pool):
124
+ i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(perms), replace=False)
125
+ perms = torch.cat([perms, perm_pool[i]])
126
+ else:
127
+ perms.extend(
128
+ [torch.randperm(max_num_chars, device=self._device) for _ in range(num_gen_perms - len(perms))]
129
+ )
130
+ perms = torch.stack(perms)
131
+ if self.perm_mirrored:
132
+ # Add complementary pairs
133
+ comp = perms.flip(-1)
134
+ # Stack in such a way that the pairs are next to each other.
135
+ perms = torch.stack([perms, comp]).transpose(0, 1).reshape(-1, max_num_chars)
136
+ # NOTE:
137
+ # The only meaningful way of permuting the EOS position is by moving it one character position at a time.
138
+ # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS
139
+ # positions will always be much less than the number of permutations (unless a low perm_num is set).
140
+ # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly
141
+ # distribute it across the chosen number of permutations.
142
+ # Add position indices of BOS and EOS
143
+ bos_idx = perms.new_zeros((len(perms), 1))
144
+ eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1)
145
+ perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1)
146
+ # Special handling for the reverse direction. This does two things:
147
+ # 1. Reverse context for the characters
148
+ # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode)
149
+ if len(perms) > 1:
150
+ perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=self._device)
151
+ return perms
152
+
153
+ def generate_attn_masks(self, perm):
154
+ """Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens)
155
+ :param perm: the permutation sequence. i = 0 is always the BOS
156
+ :return: lookahead attention masks
157
+ """
158
+ sz = perm.shape[0]
159
+ mask = torch.zeros((sz, sz), dtype=torch.bool, device=self._device)
160
+ for i in range(sz):
161
+ query_idx = perm[i]
162
+ masked_keys = perm[i + 1 :]
163
+ mask[query_idx, masked_keys] = True
164
+ content_mask = mask[:-1, :-1].clone()
165
+ mask[torch.eye(sz, dtype=torch.bool, device=self._device)] = True # mask "self"
166
+ query_mask = mask[1:, :-1]
167
+ return content_mask, query_mask
168
+
169
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
170
+ images, labels = batch
171
+ tgt = self.tokenizer.encode(labels, self._device)
172
+
173
+ # Encode the source sequence (i.e. the image codes)
174
+ memory = self.model.encode(images)
175
+
176
+ # Prepare the target sequences (input and output)
177
+ tgt_perms = self.gen_tgt_perms(tgt)
178
+ tgt_in = tgt[:, :-1]
179
+ tgt_out = tgt[:, 1:]
180
+ # The [EOS] token is not depended upon by any other token in any permutation ordering
181
+ tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
182
+
183
+ loss = 0
184
+ loss_numel = 0
185
+ n = (tgt_out != self.pad_id).sum().item()
186
+ for i, perm in enumerate(tgt_perms):
187
+ tgt_mask, query_mask = self.generate_attn_masks(perm)
188
+ out = self.model.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask)
189
+ logits = self.model.head(out).flatten(end_dim=1)
190
+ loss += n * F.cross_entropy(logits, tgt_out.flatten(), ignore_index=self.pad_id)
191
+ loss_numel += n
192
+ # After the second iteration (i.e. done with canonical and reverse orderings),
193
+ # remove the [EOS] tokens for the succeeding perms
194
+ if i == 1:
195
+ tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id, tgt_out)
196
+ n = (tgt_out != self.pad_id).sum().item()
197
+ loss /= loss_numel
198
+
199
+ self.log('loss', loss)
200
+ return loss
IndicPhotoOCR/utils/strhub/models/trba/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""
2
+ Baek, Jeonghun, Geewook Kim, Junyeop Lee, Sungrae Park, Dongyoon Han, Sangdoo Yun, Seong Joon Oh, and Hwalsuk Lee.
3
+ "What is wrong with scene text recognition model comparisons? dataset and model analysis."
4
+ In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4715-4723. 2019.
5
+
6
+ https://arxiv.org/abs/1904.01906
7
+
8
+ All source files, except `system.py`, are based on the implementation listed below,
9
+ and hence are released under the license of the original.
10
+
11
+ Source: https://github.com/clovaai/deep-text-recognition-benchmark
12
+ License: Apache License 2.0 (see LICENSE file in project root)
13
+ """
IndicPhotoOCR/utils/strhub/models/trba/feature_extraction.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from torchvision.models.resnet import BasicBlock
4
+
5
+
6
+ class ResNet_FeatureExtractor(nn.Module):
7
+ """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """
8
+
9
+ def __init__(self, input_channel, output_channel=512):
10
+ super().__init__()
11
+ self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3])
12
+
13
+ def forward(self, input):
14
+ return self.ConvNet(input)
15
+
16
+
17
+ class ResNet(nn.Module):
18
+
19
+ def __init__(self, input_channel, output_channel, block, layers):
20
+ super().__init__()
21
+
22
+ self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
23
+
24
+ self.inplanes = int(output_channel / 8)
25
+ self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
26
+ kernel_size=3, stride=1, padding=1, bias=False)
27
+ self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
28
+ self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
29
+ kernel_size=3, stride=1, padding=1, bias=False)
30
+ self.bn0_2 = nn.BatchNorm2d(self.inplanes)
31
+ self.relu = nn.ReLU(inplace=True)
32
+
33
+ self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
34
+ self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
35
+ self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
36
+ 0], kernel_size=3, stride=1, padding=1, bias=False)
37
+ self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
38
+
39
+ self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
40
+ self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
41
+ self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
42
+ 1], kernel_size=3, stride=1, padding=1, bias=False)
43
+ self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
44
+
45
+ self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
46
+ self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
47
+ self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
48
+ 2], kernel_size=3, stride=1, padding=1, bias=False)
49
+ self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
50
+
51
+ self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
52
+ self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
53
+ 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
54
+ self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
55
+ self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
56
+ 3], kernel_size=2, stride=1, padding=0, bias=False)
57
+ self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
58
+
59
+ def _make_layer(self, block, planes, blocks, stride=1):
60
+ downsample = None
61
+ if stride != 1 or self.inplanes != planes * block.expansion:
62
+ downsample = nn.Sequential(
63
+ nn.Conv2d(self.inplanes, planes * block.expansion,
64
+ kernel_size=1, stride=stride, bias=False),
65
+ nn.BatchNorm2d(planes * block.expansion),
66
+ )
67
+
68
+ layers = []
69
+ layers.append(block(self.inplanes, planes, stride, downsample))
70
+ self.inplanes = planes * block.expansion
71
+ for i in range(1, blocks):
72
+ layers.append(block(self.inplanes, planes))
73
+
74
+ return nn.Sequential(*layers)
75
+
76
+ def forward(self, x):
77
+ x = self.conv0_1(x)
78
+ x = self.bn0_1(x)
79
+ x = self.relu(x)
80
+ x = self.conv0_2(x)
81
+ x = self.bn0_2(x)
82
+ x = self.relu(x)
83
+
84
+ x = self.maxpool1(x)
85
+ x = self.layer1(x)
86
+ x = self.conv1(x)
87
+ x = self.bn1(x)
88
+ x = self.relu(x)
89
+
90
+ x = self.maxpool2(x)
91
+ x = self.layer2(x)
92
+ x = self.conv2(x)
93
+ x = self.bn2(x)
94
+ x = self.relu(x)
95
+
96
+ x = self.maxpool3(x)
97
+ x = self.layer3(x)
98
+ x = self.conv3(x)
99
+ x = self.bn3(x)
100
+ x = self.relu(x)
101
+
102
+ x = self.layer4(x)
103
+ x = self.conv4_1(x)
104
+ x = self.bn4_1(x)
105
+ x = self.relu(x)
106
+ x = self.conv4_2(x)
107
+ x = self.bn4_2(x)
108
+ x = self.relu(x)
109
+
110
+ return x
IndicPhotoOCR/utils/strhub/models/trba/model.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from strhub.models.modules import BidirectionalLSTM
4
+ from .feature_extraction import ResNet_FeatureExtractor
5
+ from .prediction import Attention
6
+ from .transformation import TPS_SpatialTransformerNetwork
7
+
8
+
9
+ class TRBA(nn.Module):
10
+
11
+ def __init__(self, img_h, img_w, num_class, num_fiducial=20, input_channel=3, output_channel=512, hidden_size=256,
12
+ use_ctc=False):
13
+ super().__init__()
14
+ """ Transformation """
15
+ self.Transformation = TPS_SpatialTransformerNetwork(
16
+ F=num_fiducial, I_size=(img_h, img_w), I_r_size=(img_h, img_w),
17
+ I_channel_num=input_channel)
18
+
19
+ """ FeatureExtraction """
20
+ self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel)
21
+ self.FeatureExtraction_output = output_channel
22
+ self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1
23
+
24
+ """ Sequence modeling"""
25
+ self.SequenceModeling = nn.Sequential(
26
+ BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
27
+ BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
28
+ self.SequenceModeling_output = hidden_size
29
+
30
+ """ Prediction """
31
+ if use_ctc:
32
+ self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)
33
+ else:
34
+ self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class)
35
+
36
+ def forward(self, image, max_label_length, text=None):
37
+ """ Transformation stage """
38
+ image = self.Transformation(image)
39
+
40
+ """ Feature extraction stage """
41
+ visual_feature = self.FeatureExtraction(image)
42
+ visual_feature = visual_feature.permute(0, 3, 1, 2) # [b, c, h, w] -> [b, w, c, h]
43
+ visual_feature = self.AdaptiveAvgPool(visual_feature) # [b, w, c, h] -> [b, w, c, 1]
44
+ visual_feature = visual_feature.squeeze(3) # [b, w, c, 1] -> [b, w, c]
45
+
46
+ """ Sequence modeling stage """
47
+ contextual_feature = self.SequenceModeling(visual_feature) # [b, num_steps, hidden_size]
48
+
49
+ """ Prediction stage """
50
+ if isinstance(self.Prediction, Attention):
51
+ prediction = self.Prediction(contextual_feature.contiguous(), text, max_label_length)
52
+ else:
53
+ prediction = self.Prediction(contextual_feature.contiguous()) # CTC
54
+
55
+ return prediction # [b, num_steps, num_class]
IndicPhotoOCR/utils/strhub/models/trba/prediction.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class Attention(nn.Module):
7
+
8
+ def __init__(self, input_size, hidden_size, num_class, num_char_embeddings=256):
9
+ super().__init__()
10
+ self.attention_cell = AttentionCell(input_size, hidden_size, num_char_embeddings)
11
+ self.hidden_size = hidden_size
12
+ self.num_class = num_class
13
+ self.generator = nn.Linear(hidden_size, num_class)
14
+ self.char_embeddings = nn.Embedding(num_class, num_char_embeddings)
15
+
16
+ def forward(self, batch_H, text, max_label_length=25):
17
+ """
18
+ input:
19
+ batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_class]
20
+ text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [SOS] token. text[:, 0] = [SOS].
21
+ output: probability distribution at each step [batch_size x num_steps x num_class]
22
+ """
23
+ batch_size = batch_H.size(0)
24
+ num_steps = max_label_length + 1 # +1 for [EOS] at end of sentence.
25
+
26
+ output_hiddens = batch_H.new_zeros((batch_size, num_steps, self.hidden_size), dtype=torch.float)
27
+ hidden = (batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float),
28
+ batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float))
29
+
30
+ if self.training:
31
+ for i in range(num_steps):
32
+ char_embeddings = self.char_embeddings(text[:, i])
33
+ # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_embeddings : f(y_{t-1})
34
+ hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings)
35
+ output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell)
36
+ probs = self.generator(output_hiddens)
37
+
38
+ else:
39
+ targets = text[0].expand(batch_size) # should be fill with [SOS] token
40
+ probs = batch_H.new_zeros((batch_size, num_steps, self.num_class), dtype=torch.float)
41
+
42
+ for i in range(num_steps):
43
+ char_embeddings = self.char_embeddings(targets)
44
+ hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings)
45
+ probs_step = self.generator(hidden[0])
46
+ probs[:, i, :] = probs_step
47
+ _, next_input = probs_step.max(1)
48
+ targets = next_input
49
+
50
+ return probs # batch_size x num_steps x num_class
51
+
52
+
53
+ class AttentionCell(nn.Module):
54
+
55
+ def __init__(self, input_size, hidden_size, num_embeddings):
56
+ super().__init__()
57
+ self.i2h = nn.Linear(input_size, hidden_size, bias=False)
58
+ self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias
59
+ self.score = nn.Linear(hidden_size, 1, bias=False)
60
+ self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
61
+ self.hidden_size = hidden_size
62
+
63
+ def forward(self, prev_hidden, batch_H, char_embeddings):
64
+ # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
65
+ batch_H_proj = self.i2h(batch_H)
66
+ prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
67
+ e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1
68
+
69
+ alpha = F.softmax(e, dim=1)
70
+ context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel
71
+ concat_context = torch.cat([context, char_embeddings], 1) # batch_size x (num_channel + num_embedding)
72
+ cur_hidden = self.rnn(concat_context, prev_hidden)
73
+ return cur_hidden, alpha