Isaacgonzales committed on
Commit
d02e83e
1 Parent(s): d59ff1a
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +39 -0
  2. model.py +52 -0
  3. post.py +42 -0
  4. requirements.txt +7 -0
  5. strhub/__init__.py +0 -0
  6. strhub/__pycache__/__init__.cpython-37.pyc +0 -0
  7. strhub/data/.ipynb_checkpoints/dataset-checkpoint.py +137 -0
  8. strhub/data/.ipynb_checkpoints/module-checkpoint.py +107 -0
  9. strhub/data/__init__.py +0 -0
  10. strhub/data/__pycache__/__init__.cpython-37.pyc +0 -0
  11. strhub/data/__pycache__/aa_overrides.cpython-37.pyc +0 -0
  12. strhub/data/__pycache__/augment.cpython-37.pyc +0 -0
  13. strhub/data/__pycache__/dataset.cpython-37.pyc +0 -0
  14. strhub/data/__pycache__/module.cpython-37.pyc +0 -0
  15. strhub/data/__pycache__/utils.cpython-37.pyc +0 -0
  16. strhub/data/aa_overrides.py +46 -0
  17. strhub/data/augment.py +111 -0
  18. strhub/data/dataset.py +137 -0
  19. strhub/data/module.py +107 -0
  20. strhub/data/utils.py +148 -0
  21. strhub/models/.ipynb_checkpoints/base-checkpoint.py +202 -0
  22. strhub/models/.ipynb_checkpoints/modules-checkpoint.py +20 -0
  23. strhub/models/.ipynb_checkpoints/utils-checkpoint.py +123 -0
  24. strhub/models/__init__.py +0 -0
  25. strhub/models/__pycache__/__init__.cpython-37.pyc +0 -0
  26. strhub/models/__pycache__/base.cpython-37.pyc +0 -0
  27. strhub/models/__pycache__/utils.cpython-37.pyc +0 -0
  28. strhub/models/abinet/LICENSE +25 -0
  29. strhub/models/abinet/__init__.py +13 -0
  30. strhub/models/abinet/attention.py +100 -0
  31. strhub/models/abinet/backbone.py +24 -0
  32. strhub/models/abinet/model.py +31 -0
  33. strhub/models/abinet/model_abinet_iter.py +39 -0
  34. strhub/models/abinet/model_alignment.py +28 -0
  35. strhub/models/abinet/model_language.py +50 -0
  36. strhub/models/abinet/model_vision.py +45 -0
  37. strhub/models/abinet/resnet.py +72 -0
  38. strhub/models/abinet/system.py +172 -0
  39. strhub/models/abinet/transformer.py +143 -0
  40. strhub/models/base.py +202 -0
  41. strhub/models/crnn/LICENSE +21 -0
  42. strhub/models/crnn/__init__.py +13 -0
  43. strhub/models/crnn/model.py +62 -0
  44. strhub/models/crnn/system.py +43 -0
  45. strhub/models/modules.py +20 -0
  46. strhub/models/parseq/__init__.py +0 -0
  47. strhub/models/parseq/__pycache__/__init__.cpython-37.pyc +0 -0
  48. strhub/models/parseq/__pycache__/modules.cpython-37.pyc +0 -0
  49. strhub/models/parseq/__pycache__/system.cpython-37.pyc +0 -0
  50. strhub/models/parseq/modules.py +126 -0
app.py ADDED
@@ -0,0 +1,39 @@
+ import gradio as gr
+ import shutil
+
+ import urllib.request
+ import sys
+ import os
+ import urllib.request
+ import zipfile
+
+ sys.path.append(".")
+
+ from model import prediction
+
+ gr.close_all()
+
+ # https://storage.googleapis.com/models-gradio/products/products.zip
+
+ urllib.request.urlretrieve("https://storage.googleapis.com/models-gradio/products/products.zip", "products.zip")
+
+
+ with zipfile.ZipFile("products.zip", 'r') as zip_ref:
+     zip_ref.extractall()
+
+ def predict(img):
+     name_image = img.split("/")[-1]
+
+     prediction_img, text = prediction(img)
+
+     return str(text), prediction_img
+
+
+ sample_images = ["dataset/" + name for name in os.listdir("dataset")]
+
+
+ gr.Interface(fn=predict,
+              inputs=[gr.Image(label="image à tester", type="filepath")],
+              outputs=[gr.Textbox(label="analyse"), gr.Image(label="résultat")],
+              css="footer {visibility: hidden} body, .gradio-container {background-color: white}",
+              examples=sample_images).launch(server_name="0.0.0.0", share=False)
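A possible refinement, sketched here rather than taken from the commit: guard the download and extraction so that restarting the app does not re-fetch the archive (same URL and file names as above; the assumption that the archive contains dataset/ and weights/ follows from how app.py and model.py use them).

    import os
    import urllib.request
    import zipfile

    if not os.path.exists("products.zip"):
        urllib.request.urlretrieve("https://storage.googleapis.com/models-gradio/products/products.zip", "products.zip")
    if not os.path.isdir("dataset"):  # assumption: the archive provides dataset/ and weights/
        with zipfile.ZipFile("products.zip", "r") as zip_ref:
            zip_ref.extractall()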
model.py ADDED
@@ -0,0 +1,52 @@
+
+ from strhub.data.module import SceneTextDataModule
+ from strhub.models.utils import load_from_checkpoint
+ from post import filter_mask
+ import segmentation_models_pytorch as smp
+ import albumentations as albu
+ from torchvision import transforms
+ from PIL import Image
+ import torch
+ import cv2
+
+ model_recog = load_from_checkpoint("weights/parseq/last.ckpt").eval().to("cpu")
+ img_transform = SceneTextDataModule.get_transform(model_recog.hparams.img_size)
+
+ model = torch.load('weights/best_model.pth').to("cpu")
+ model.eval()
+ model.float()
+
+ SHAPE_X = 384
+ SHAPE_Y = 384
+
+
+ def prediction(image_path):
+     image = cv2.imread(image_path)
+     image_original = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     preprocessing_fn = smp.encoders.get_preprocessing_fn('resnet50')
+     transform = albu.Compose([
+         albu.Lambda(image=preprocessing_fn), albu.Resize(SHAPE_X, SHAPE_Y)
+     ])
+
+     image_result = transform(image=image_original)["image"]
+     transform = transforms.ToTensor()
+     tensor = transform(image_result)
+     tensor = torch.unsqueeze(tensor, 0)
+     output = model.predict(tensor.float())
+
+     result, img_vis = filter_mask(output, image_original)
+
+     image = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
+     im_pil = Image.fromarray(image)
+     image = img_transform(im_pil).unsqueeze(0).to("cpu")
+
+     p = model_recog(image).softmax(-1)
+     pred, p = model_recog.tokenizer.decode(p)
+     print(f'{image_path}: {pred[0]}')
+
+
+     return img_vis, pred[0]
+
+
+
+
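For orientation, a minimal usage sketch of prediction() (the image path is a placeholder; the weights/ folder must already be in place as described above):

    from model import prediction

    annotated, text = prediction("dataset/example.jpg")  # hypothetical file name
    print(text)  # string decoded by the PARSeq recognizer
    # `annotated` is the original image with the detected text box drawn by filter_mask().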
post.py ADDED
@@ -0,0 +1,42 @@
+ import cv2
+ import numpy as np
+
+ def filter_mask(output, image):
+     image_h, image_w = image.shape[:2]
+
+     predict_mask = (output.squeeze().cpu().numpy().round())
+     predict_mask = predict_mask.astype('uint8')*255
+     predict_mask = cv2.resize(predict_mask, (image_w, image_h))
+
+     ret, thresh = cv2.threshold(predict_mask, 127, 255, 0)
+     contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+     points = contours[0]
+
+     rect = cv2.minAreaRect(points)
+     box = cv2.boxPoints(rect)
+     box = np.int0(box)
+
+     img_vis = cv2.drawContours(image.copy(), [box], 0, (255, 0, 0), 6)
+
+     (cX, cY), (w, h), angle = rect
+
+     if w < h:
+         angle -= 90
+
+     #(cX, cY) = (image_w // 2, image_h // 2)
+
+     M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
+     rotated = cv2.warpAffine(image, M, (round(image_w), round(image_h)))
+
+     ones = np.ones(shape=(len(points), 1))
+     points = np.squeeze(points)
+     points_ones = np.hstack([points, ones])
+     points_rotate = M.dot(points_ones.T).T
+
+     (cX, cY), (w, h), angle = cv2.minAreaRect(points_rotate.round().astype(int))
+
+     if angle < 45:
+         crop_img = rotated[round(cY-h//2):round(cY+h//2), round(cX - w//2):round(cX + w//2)]
+     else:
+         crop_img = rotated[round(cY-w//2):round(cY+w//2), round(cX - h//2):round(cX + h//2)]
+     return crop_img, img_vis
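The crop logic above hinges on cv2.minAreaRect and cv2.boxPoints; a small synthetic illustration (not part of the commit) of how a rotated box is recovered from a mask:

    import cv2
    import numpy as np

    mask = np.zeros((200, 300), dtype=np.uint8)
    pts = cv2.boxPoints(((150, 100), (120, 40), 30))  # centre, (w, h), angle in degrees
    cv2.fillPoly(mask, [pts.astype(int)], 255)        # synthetic segmentation mask

    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    (cx, cy), (w, h), angle = cv2.minAreaRect(contours[0])
    # Roughly recovers the 120 x 40 box; the (w, h)/angle convention differs between
    # OpenCV versions, which is why filter_mask() normalizes the angle before rotating.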
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ opencv-python
+ torch
+ torchvision
+ gradio
+ segmentation_models_pytorch
+ albumentations
+ Pillow
strhub/__init__.py ADDED
File without changes
strhub/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (134 Bytes).
 
strhub/data/.ipynb_checkpoints/dataset-checkpoint.py ADDED
@@ -0,0 +1,137 @@
+ # Scene Text Recognition Model Hub
+ # Copyright 2022 Darwin Bautista
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import glob
+ import io
+ import logging
+ import unicodedata
+ from pathlib import Path, PurePath
+ from typing import Callable, Optional, Union
+
+ import lmdb
+ from PIL import Image
+ from torch.utils.data import Dataset, ConcatDataset
+
+ from strhub.data.utils import CharsetAdapter
+
+ log = logging.getLogger(__name__)
+
+
+ def build_tree_dataset(root: Union[PurePath, str], *args, **kwargs):
+     try:
+         kwargs.pop('root')  # prevent 'root' from being passed via kwargs
+     except KeyError:
+         pass
+     root = Path(root).absolute()
+     log.info(f'dataset root:\t{root}')
+     datasets = []
+     for mdb in glob.glob(str(root / '**/data.mdb'), recursive=True):
+         mdb = Path(mdb)
+         ds_name = str(mdb.parent.relative_to(root))
+         ds_root = str(mdb.parent.absolute())
+         dataset = LmdbDataset(ds_root, *args, **kwargs)
+         log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}')
+         datasets.append(dataset)
+     return ConcatDataset(datasets)
+
+
+ class LmdbDataset(Dataset):
+     """Dataset interface to an LMDB database.
+
+     It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned
+     as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset.
+     Labels are transformed according to the charset.
+     """
+
+     def __init__(self, root: str, charset: str, max_label_len: int, min_image_dim: int = 0,
+                  remove_whitespace: bool = True, normalize_unicode: bool = True,
+                  unlabelled: bool = False, transform: Optional[Callable] = None):
+         self._env = None
+         self.root = root
+         self.unlabelled = unlabelled
+         self.transform = transform
+         self.labels = []
+         self.filtered_index_list = []
+         self.num_samples = self._preprocess_labels(charset, remove_whitespace, normalize_unicode,
+                                                    max_label_len, min_image_dim)
+
+     def __del__(self):
+         if self._env is not None:
+             self._env.close()
+             self._env = None
+
+     def _create_env(self):
+         return lmdb.open(self.root, max_readers=1, readonly=True, create=False,
+                          readahead=False, meminit=False, lock=False)
+
+     @property
+     def env(self):
+         if self._env is None:
+             self._env = self._create_env()
+         return self._env
+
+     def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim):
+         charset_adapter = CharsetAdapter(charset)
+         with self._create_env() as env, env.begin() as txn:
+             num_samples = int(txn.get('num-samples'.encode()))
+             if self.unlabelled:
+                 return num_samples
+             for index in range(num_samples):
+                 index += 1  # lmdb starts with 1
+                 label_key = f'label-{index:09d}'.encode()
+                 label = txn.get(label_key).decode()
+                 # Normally, whitespace is removed from the labels.
+                 if remove_whitespace:
+                     label = ''.join(label.split())
+                 # Normalize unicode composites (if any) and convert to compatible ASCII characters
+                 if normalize_unicode:
+                     label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode()
+                 # Filter by length before removing unsupported characters. The original label might be too long.
+                 if len(label) > max_label_len:
+                     continue
+                 label = charset_adapter(label)
+                 # We filter out samples which don't contain any supported characters
+                 if not label:
+                     continue
+                 # Filter images that are too small.
+                 if min_image_dim > 0:
+                     img_key = f'image-{index:09d}'.encode()
+                     buf = io.BytesIO(txn.get(img_key))
+                     w, h = Image.open(buf).size
+                     if w < self.min_image_dim or h < self.min_image_dim:
+                         continue
+                 self.labels.append(label)
+                 self.filtered_index_list.append(index)
+         return len(self.labels)
+
+     def __len__(self):
+         return self.num_samples
+
+     def __getitem__(self, index):
+         if self.unlabelled:
+             label = index
+         else:
+             label = self.labels[index]
+             index = self.filtered_index_list[index]
+
+         img_key = f'image-{index:09d}'.encode()
+         with self.env.begin() as txn:
+             imgbuf = txn.get(img_key)
+         buf = io.BytesIO(imgbuf)
+         img = Image.open(buf).convert('RGB')
+
+         if self.transform is not None:
+             img = self.transform(img)
+
+         return img, label
strhub/data/.ipynb_checkpoints/module-checkpoint.py ADDED
@@ -0,0 +1,107 @@
+ # Scene Text Recognition Model Hub
+ # Copyright 2022 Darwin Bautista
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from pathlib import PurePath
+ from typing import Optional, Callable, Sequence, Tuple
+
+ import pytorch_lightning as pl
+ from torch.utils.data import DataLoader
+ from torchvision import transforms as T
+
+ from .dataset import build_tree_dataset, LmdbDataset
+
+
+ class SceneTextDataModule(pl.LightningDataModule):
+     TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80')
+     TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80')
+     TEST_NEW = ('ArT', 'COCOv1.4', 'Uber')
+     TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW))
+
+     def __init__(self, root_dir: str, train_dir: str, img_size: Sequence[int], max_label_length: int,
+                  charset_train: str, charset_test: str, batch_size: int, num_workers: int, augment: bool,
+                  remove_whitespace: bool = True, normalize_unicode: bool = True,
+                  min_image_dim: int = 0, rotation: int = 0, collate_fn: Optional[Callable] = None):
+         super().__init__()
+         self.root_dir = root_dir
+         self.train_dir = train_dir
+         self.img_size = tuple(img_size)
+         self.max_label_length = max_label_length
+         self.charset_train = charset_train
+         self.charset_test = charset_test
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.augment = augment
+         self.remove_whitespace = remove_whitespace
+         self.normalize_unicode = normalize_unicode
+         self.min_image_dim = min_image_dim
+         self.rotation = rotation
+         self.collate_fn = collate_fn
+         self._train_dataset = None
+         self._val_dataset = None
+
+     @staticmethod
+     def get_transform(img_size: Tuple[int], augment: bool = False, rotation: int = 0):
+         transforms = []
+         if augment:
+             from .augment import rand_augment_transform
+             transforms.append(rand_augment_transform())
+         if rotation:
+             transforms.append(lambda img: img.rotate(rotation, expand=True))
+         transforms.extend([
+             T.Resize(img_size, T.InterpolationMode.BICUBIC),
+             T.ToTensor(),
+             T.Normalize(0.5, 0.5)
+         ])
+         return T.Compose(transforms)
+
+     @property
+     def train_dataset(self):
+         if self._train_dataset is None:
+             transform = self.get_transform(self.img_size, self.augment)
+             root = PurePath(self.root_dir, 'train', self.train_dir)
+             self._train_dataset = build_tree_dataset(root, self.charset_train, self.max_label_length,
+                                                      self.min_image_dim, self.remove_whitespace, self.normalize_unicode,
+                                                      transform=transform)
+         return self._train_dataset
+
+     @property
+     def val_dataset(self):
+         if self._val_dataset is None:
+             transform = self.get_transform(self.img_size)
+             root = PurePath(self.root_dir, 'val')
+             self._val_dataset = build_tree_dataset(root, self.charset_test, self.max_label_length,
+                                                    self.min_image_dim, self.remove_whitespace, self.normalize_unicode,
+                                                    transform=transform)
+         return self._val_dataset
+
+     def train_dataloader(self):
+         return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
+                           num_workers=self.num_workers, persistent_workers=self.num_workers > 0,
+                           pin_memory=True, collate_fn=self.collate_fn)
+
+     def val_dataloader(self):
+         return DataLoader(self.val_dataset, batch_size=self.batch_size,
+                           num_workers=self.num_workers, persistent_workers=self.num_workers > 0,
+                           pin_memory=True, collate_fn=self.collate_fn)
+
+     def test_dataloaders(self, subset):
+         transform = self.get_transform(self.img_size, rotation=self.rotation)
+         root = PurePath(self.root_dir, 'test')
+         datasets = {s: LmdbDataset(str(root / s), self.charset_test, self.max_label_length,
+                                    self.min_image_dim, self.remove_whitespace, self.normalize_unicode,
+                                    transform=transform) for s in subset}
+         return {k: DataLoader(v, batch_size=self.batch_size, num_workers=self.num_workers,
+                               pin_memory=True, collate_fn=self.collate_fn)
+                 for k, v in datasets.items()}
strhub/data/__init__.py ADDED
File without changes
strhub/data/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (139 Bytes).
 
strhub/data/__pycache__/aa_overrides.cpython-37.pyc ADDED
Binary file (1.22 kB).
 
strhub/data/__pycache__/augment.cpython-37.pyc ADDED
Binary file (3.54 kB).
 
strhub/data/__pycache__/dataset.cpython-37.pyc ADDED
Binary file (3.97 kB).
 
strhub/data/__pycache__/module.cpython-37.pyc ADDED
Binary file (4.11 kB).
 
strhub/data/__pycache__/utils.cpython-37.pyc ADDED
Binary file (6.83 kB).
 
strhub/data/aa_overrides.py ADDED
@@ -0,0 +1,46 @@
+ # Scene Text Recognition Model Hub
+ # Copyright 2022 Darwin Bautista
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Extends default ops to accept optional parameters."""
+ from functools import partial
+
+ from timm.data.auto_augment import _LEVEL_DENOM, _randomly_negate, LEVEL_TO_ARG, NAME_TO_OP, rotate
+
+
+ def rotate_expand(img, degrees, **kwargs):
+     """Rotate operation with expand=True to avoid cutting off the characters"""
+     kwargs['expand'] = True
+     return rotate(img, degrees, **kwargs)
+
+
+ def _level_to_arg(level, hparams, key, default):
+     magnitude = hparams.get(key, default)
+     level = (level / _LEVEL_DENOM) * magnitude
+     level = _randomly_negate(level)
+     return level,
+
+
+ def apply():
+     # Overrides
+     NAME_TO_OP.update({
+         'Rotate': rotate_expand
+     })
+     LEVEL_TO_ARG.update({
+         'Rotate': partial(_level_to_arg, key='rotate_deg', default=30.),
+         'ShearX': partial(_level_to_arg, key='shear_x_pct', default=0.3),
+         'ShearY': partial(_level_to_arg, key='shear_y_pct', default=0.3),
+         'TranslateXRel': partial(_level_to_arg, key='translate_x_pct', default=0.45),
+         'TranslateYRel': partial(_level_to_arg, key='translate_y_pct', default=0.45),
+     })
strhub/data/augment.py ADDED
@@ -0,0 +1,111 @@
+ # Scene Text Recognition Model Hub
+ # Copyright 2022 Darwin Bautista
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from functools import partial
+
+ import imgaug.augmenters as iaa
+ import numpy as np
+ from PIL import ImageFilter, Image
+ from timm.data import auto_augment
+
+ from strhub.data import aa_overrides
+
+ aa_overrides.apply()
+
+ _OP_CACHE = {}
+
+
+ def _get_op(key, factory):
+     try:
+         op = _OP_CACHE[key]
+     except KeyError:
+         op = factory()
+         _OP_CACHE[key] = op
+     return op
+
+
+ def _get_param(level, img, max_dim_factor, min_level=1):
+     max_level = max(min_level, max_dim_factor * max(img.size))
+     return round(min(level, max_level))
+
+
+ def gaussian_blur(img, radius, **__):
+     radius = _get_param(radius, img, 0.02)
+     key = 'gaussian_blur_' + str(radius)
+     op = _get_op(key, lambda: ImageFilter.GaussianBlur(radius))
+     return img.filter(op)
+
+
+ def motion_blur(img, k, **__):
+     k = _get_param(k, img, 0.08, 3) | 1  # bin to odd values
+     key = 'motion_blur_' + str(k)
+     op = _get_op(key, lambda: iaa.MotionBlur(k))
+     return Image.fromarray(op(image=np.asarray(img)))
+
+
+ def gaussian_noise(img, scale, **_):
+     scale = _get_param(scale, img, 0.25) | 1  # bin to odd values
+     key = 'gaussian_noise_' + str(scale)
+     op = _get_op(key, lambda: iaa.AdditiveGaussianNoise(scale=scale))
+     return Image.fromarray(op(image=np.asarray(img)))
+
+
+ def poisson_noise(img, lam, **_):
+     lam = _get_param(lam, img, 0.2) | 1  # bin to odd values
+     key = 'poisson_noise_' + str(lam)
+     op = _get_op(key, lambda: iaa.AdditivePoissonNoise(lam))
+     return Image.fromarray(op(image=np.asarray(img)))
+
+
+ def _level_to_arg(level, _hparams, max):
+     level = max * level / auto_augment._LEVEL_DENOM
+     return level,
+
+
+ _RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy()
+ _RAND_TRANSFORMS.remove('SharpnessIncreasing')  # remove, interferes with *blur ops
+ _RAND_TRANSFORMS.extend([
+     'GaussianBlur',
+     # 'MotionBlur',
+     # 'GaussianNoise',
+     'PoissonNoise'
+ ])
+ auto_augment.LEVEL_TO_ARG.update({
+     'GaussianBlur': partial(_level_to_arg, max=4),
+     'MotionBlur': partial(_level_to_arg, max=20),
+     'GaussianNoise': partial(_level_to_arg, max=0.1 * 255),
+     'PoissonNoise': partial(_level_to_arg, max=40)
+ })
+ auto_augment.NAME_TO_OP.update({
+     'GaussianBlur': gaussian_blur,
+     'MotionBlur': motion_blur,
+     'GaussianNoise': gaussian_noise,
+     'PoissonNoise': poisson_noise
+ })
+
+
+ def rand_augment_transform(magnitude=5, num_layers=3):
+     # These are tuned for magnitude=5, which means that effective magnitudes are half of these values.
+     hparams = {
+         'rotate_deg': 30,
+         'shear_x_pct': 0.9,
+         'shear_y_pct': 0.2,
+         'translate_x_pct': 0.10,
+         'translate_y_pct': 0.30
+     }
+     ra_ops = auto_augment.rand_augment_ops(magnitude, hparams, transforms=_RAND_TRANSFORMS)
+     # Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice)
+     choice_weights = [1. / len(ra_ops) for _ in range(len(ra_ops))]
+     return auto_augment.RandAugment(ra_ops, num_layers, choice_weights)
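As a standalone sketch, the exported rand_augment_transform() can be applied directly to a PIL image (the image below is a placeholder):

    from PIL import Image
    from strhub.data.augment import rand_augment_transform

    augment = rand_augment_transform()           # magnitude=5, 3 layers by default
    img = Image.new("RGB", (128, 32), "white")   # placeholder crop; any PIL image works
    augmented = augment(img)                     # PIL image with 3 randomly chosen ops applied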
strhub/data/dataset.py ADDED
@@ -0,0 +1,137 @@
+ # Scene Text Recognition Model Hub
+ # Copyright 2022 Darwin Bautista
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import glob
+ import io
+ import logging
+ import unicodedata
+ from pathlib import Path, PurePath
+ from typing import Callable, Optional, Union
+
+ import lmdb
+ from PIL import Image
+ from torch.utils.data import Dataset, ConcatDataset
+
+ from strhub.data.utils import CharsetAdapter
+
+ log = logging.getLogger(__name__)
+
+
+ def build_tree_dataset(root: Union[PurePath, str], *args, **kwargs):
+     try:
+         kwargs.pop('root')  # prevent 'root' from being passed via kwargs
+     except KeyError:
+         pass
+     root = Path(root).absolute()
+     log.info(f'dataset root:\t{root}')
+     datasets = []
+     for mdb in glob.glob(str(root / '**/data.mdb'), recursive=True):
+         mdb = Path(mdb)
+         ds_name = str(mdb.parent.relative_to(root))
+         ds_root = str(mdb.parent.absolute())
+         dataset = LmdbDataset(ds_root, *args, **kwargs)
+         log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}')
+         datasets.append(dataset)
+     return ConcatDataset(datasets)
+
+
+ class LmdbDataset(Dataset):
+     """Dataset interface to an LMDB database.
+
+     It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned
+     as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset.
+     Labels are transformed according to the charset.
+     """
+
+     def __init__(self, root: str, charset: str, max_label_len: int, min_image_dim: int = 0,
+                  remove_whitespace: bool = True, normalize_unicode: bool = True,
+                  unlabelled: bool = False, transform: Optional[Callable] = None):
+         self._env = None
+         self.root = root
+         self.unlabelled = unlabelled
+         self.transform = transform
+         self.labels = []
+         self.filtered_index_list = []
+         self.num_samples = self._preprocess_labels(charset, remove_whitespace, normalize_unicode,
+                                                    max_label_len, min_image_dim)
+
+     def __del__(self):
+         if self._env is not None:
+             self._env.close()
+             self._env = None
+
+     def _create_env(self):
+         return lmdb.open(self.root, max_readers=1, readonly=True, create=False,
+                          readahead=False, meminit=False, lock=False)
+
+     @property
+     def env(self):
+         if self._env is None:
+             self._env = self._create_env()
+         return self._env
+
+     def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim):
+         charset_adapter = CharsetAdapter(charset)
+         with self._create_env() as env, env.begin() as txn:
+             num_samples = int(txn.get('num-samples'.encode()))
+             if self.unlabelled:
+                 return num_samples
+             for index in range(num_samples):
+                 index += 1  # lmdb starts with 1
+                 label_key = f'label-{index:09d}'.encode()
+                 label = txn.get(label_key).decode()
+                 # Normally, whitespace is removed from the labels.
+                 if remove_whitespace:
+                     label = ''.join(label.split())
+                 # Normalize unicode composites (if any) and convert to compatible ASCII characters
+                 if normalize_unicode:
+                     label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode()
+                 # Filter by length before removing unsupported characters. The original label might be too long.
+                 if len(label) > max_label_len:
+                     continue
+                 label = charset_adapter(label)
+                 # We filter out samples which don't contain any supported characters
+                 if not label:
+                     continue
+                 # Filter images that are too small.
+                 if min_image_dim > 0:
+                     img_key = f'image-{index:09d}'.encode()
+                     buf = io.BytesIO(txn.get(img_key))
+                     w, h = Image.open(buf).size
+                     if w < self.min_image_dim or h < self.min_image_dim:
+                         continue
+                 self.labels.append(label)
+                 self.filtered_index_list.append(index)
+         return len(self.labels)
+
+     def __len__(self):
+         return self.num_samples
+
+     def __getitem__(self, index):
+         if self.unlabelled:
+             label = index
+         else:
+             label = self.labels[index]
+             index = self.filtered_index_list[index]
+
+         img_key = f'image-{index:09d}'.encode()
+         with self.env.begin() as txn:
+             imgbuf = txn.get(img_key)
+         buf = io.BytesIO(imgbuf)
+         img = Image.open(buf).convert('RGB')
+
+         if self.transform is not None:
+             img = self.transform(img)
+
+         return img, label
strhub/data/module.py ADDED
@@ -0,0 +1,107 @@
+ # Scene Text Recognition Model Hub
+ # Copyright 2022 Darwin Bautista
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from pathlib import PurePath
+ from typing import Optional, Callable, Sequence, Tuple
+
+ import pytorch_lightning as pl
+ from torch.utils.data import DataLoader
+ from torchvision import transforms as T
+
+ from .dataset import build_tree_dataset, LmdbDataset
+
+
+ class SceneTextDataModule(pl.LightningDataModule):
+     TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80')
+     TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80')
+     TEST_NEW = ('ArT', 'COCOv1.4', 'Uber')
+     TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW))
+
+     def __init__(self, root_dir: str, train_dir: str, img_size: Sequence[int], max_label_length: int,
+                  charset_train: str, charset_test: str, batch_size: int, num_workers: int, augment: bool,
+                  remove_whitespace: bool = True, normalize_unicode: bool = True,
+                  min_image_dim: int = 0, rotation: int = 0, collate_fn: Optional[Callable] = None):
+         super().__init__()
+         self.root_dir = root_dir
+         self.train_dir = train_dir
+         self.img_size = tuple(img_size)
+         self.max_label_length = max_label_length
+         self.charset_train = charset_train
+         self.charset_test = charset_test
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.augment = augment
+         self.remove_whitespace = remove_whitespace
+         self.normalize_unicode = normalize_unicode
+         self.min_image_dim = min_image_dim
+         self.rotation = rotation
+         self.collate_fn = collate_fn
+         self._train_dataset = None
+         self._val_dataset = None
+
+     @staticmethod
+     def get_transform(img_size: Tuple[int], augment: bool = False, rotation: int = 0):
+         transforms = []
+         if augment:
+             from .augment import rand_augment_transform
+             transforms.append(rand_augment_transform())
+         if rotation:
+             transforms.append(lambda img: img.rotate(rotation, expand=True))
+         transforms.extend([
+             T.Resize(img_size, T.InterpolationMode.BICUBIC),
+             T.ToTensor(),
+             T.Normalize(0.5, 0.5)
+         ])
+         return T.Compose(transforms)
+
+     @property
+     def train_dataset(self):
+         if self._train_dataset is None:
+             transform = self.get_transform(self.img_size, self.augment)
+             root = PurePath(self.root_dir, 'train', self.train_dir)
+             self._train_dataset = build_tree_dataset(root, self.charset_train, self.max_label_length,
+                                                      self.min_image_dim, self.remove_whitespace, self.normalize_unicode,
+                                                      transform=transform)
+         return self._train_dataset
+
+     @property
+     def val_dataset(self):
+         if self._val_dataset is None:
+             transform = self.get_transform(self.img_size)
+             root = PurePath(self.root_dir, 'val')
+             self._val_dataset = build_tree_dataset(root, self.charset_test, self.max_label_length,
+                                                    self.min_image_dim, self.remove_whitespace, self.normalize_unicode,
+                                                    transform=transform)
+         return self._val_dataset
+
+     def train_dataloader(self):
+         return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
+                           num_workers=self.num_workers, persistent_workers=self.num_workers > 0,
+                           pin_memory=True, collate_fn=self.collate_fn)
+
+     def val_dataloader(self):
+         return DataLoader(self.val_dataset, batch_size=self.batch_size,
+                           num_workers=self.num_workers, persistent_workers=self.num_workers > 0,
+                           pin_memory=True, collate_fn=self.collate_fn)
+
+     def test_dataloaders(self, subset):
+         transform = self.get_transform(self.img_size, rotation=self.rotation)
+         root = PurePath(self.root_dir, 'test')
+         datasets = {s: LmdbDataset(str(root / s), self.charset_test, self.max_label_length,
+                                    self.min_image_dim, self.remove_whitespace, self.normalize_unicode,
+                                    transform=transform) for s in subset}
+         return {k: DataLoader(v, batch_size=self.batch_size, num_workers=self.num_workers,
+                               pin_memory=True, collate_fn=self.collate_fn)
+                 for k, v in datasets.items()}
strhub/data/utils.py ADDED
@@ -0,0 +1,148 @@
+ # Scene Text Recognition Model Hub
+ # Copyright 2022 Darwin Bautista
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ from abc import ABC, abstractmethod
+ from itertools import groupby
+ from typing import List, Optional, Tuple
+
+ import torch
+ from torch import Tensor
+ from torch.nn.utils.rnn import pad_sequence
+
+
+ class CharsetAdapter:
+     """Transforms labels according to the target charset."""
+
+     def __init__(self, target_charset) -> None:
+         super().__init__()
+         self.lowercase_only = target_charset == target_charset.lower()
+         self.uppercase_only = target_charset == target_charset.upper()
+         self.unsupported = f'[^{re.escape(target_charset)}]'
+
+     def __call__(self, label):
+         if self.lowercase_only:
+             label = label.lower()
+         elif self.uppercase_only:
+             label = label.upper()
+         # Remove unsupported characters
+         label = re.sub(self.unsupported, '', label)
+         return label
+
+
+ class BaseTokenizer(ABC):
+
+     def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
+         self._itos = specials_first + tuple(charset) + specials_last
+         self._stoi = {s: i for i, s in enumerate(self._itos)}
+
+     def __len__(self):
+         return len(self._itos)
+
+     def _tok2ids(self, tokens: str) -> List[int]:
+         return [self._stoi[s] for s in tokens]
+
+     def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
+         tokens = [self._itos[i] for i in token_ids]
+         return ''.join(tokens) if join else tokens
+
+     @abstractmethod
+     def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
+         """Encode a batch of labels to a representation suitable for the model.
+
+         Args:
+             labels: List of labels. Each can be of arbitrary length.
+             device: Create tensor on this device.
+
+         Returns:
+             Batched tensor representation padded to the max label length. Shape: N, L
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
+         """Internal method which performs the necessary filtering prior to decoding."""
+         raise NotImplementedError
+
+     def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
+         """Decode a batch of token distributions.
+
+         Args:
+             token_dists: softmax probabilities over the token distribution. Shape: N, L, C
+             raw: return unprocessed labels (will return list of list of strings)
+
+         Returns:
+             list of string labels (arbitrary length) and
+             their corresponding sequence probabilities as a list of Tensors
+         """
+         batch_tokens = []
+         batch_probs = []
+         for dist in token_dists:
+             probs, ids = dist.max(-1)  # greedy selection
+             if not raw:
+                 probs, ids = self._filter(probs, ids)
+             tokens = self._ids2tok(ids, not raw)
+             batch_tokens.append(tokens)
+             batch_probs.append(probs)
+         return batch_tokens, batch_probs
+
+
+ class Tokenizer(BaseTokenizer):
+     BOS = '[B]'
+     EOS = '[E]'
+     PAD = '[P]'
+
+     def __init__(self, charset: str) -> None:
+         specials_first = (self.EOS,)
+         specials_last = (self.BOS, self.PAD)
+         super().__init__(charset, specials_first, specials_last)
+         self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last]
+
+     def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
+         batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device)
+                  for y in labels]
+         return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)
+
+     def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
+         ids = ids.tolist()
+         try:
+             eos_idx = ids.index(self.eos_id)
+         except ValueError:
+             eos_idx = len(ids)  # Nothing to truncate.
+         # Truncate after EOS
+         ids = ids[:eos_idx]
+         probs = probs[:eos_idx + 1]  # but include prob. for EOS (if it exists)
+         return probs, ids
+
+
+ class CTCTokenizer(BaseTokenizer):
+     BLANK = '[B]'
+
+     def __init__(self, charset: str) -> None:
+         # BLANK uses index == 0 by default
+         super().__init__(charset, specials_first=(self.BLANK,))
+         self.blank_id = self._stoi[self.BLANK]
+
+     def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
+         # We use a padded representation since we don't want to use CUDNN's CTC implementation
+         batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels]
+         return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)
+
+     def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
+         # Best path decoding:
+         ids = list(zip(*groupby(ids.tolist())))[0]  # Remove duplicate tokens
+         ids = [x for x in ids if x != self.blank_id]  # Remove BLANKs
+         # `probs` is just pass-through since all positions are considered part of the path
+         return probs, ids
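A short round-trip sketch of the Tokenizer (the charset here is an arbitrary example, not the project's configured one):

    import torch
    from strhub.data.utils import Tokenizer

    tokenizer = Tokenizer("0123456789abcdefghijklmnopqrstuvwxyz")
    targets = tokenizer.encode(["hello", "42"])   # LongTensor padded with [P], wrapped in [B]/[E]
    print(targets.shape)                          # torch.Size([2, 7]): longest label + BOS/EOS

    dummy = torch.randn(2, 10, len(tokenizer)).softmax(-1)   # stand-in for model output (N, L, C)
    labels, probs = tokenizer.decode(dummy)                  # greedy strings, truncated at the first [E]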
strhub/models/.ipynb_checkpoints/base-checkpoint.py ADDED
@@ -0,0 +1,202 @@
+ # Scene Text Recognition Model Hub
+ # Copyright 2022 Darwin Bautista
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, List
+
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn.functional as F
+ from nltk import edit_distance
+ from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
+ from timm.optim import create_optimizer_v2
+ from torch import Tensor
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import OneCycleLR
+
+ from strhub.data.utils import CharsetAdapter, CTCTokenizer, Tokenizer, BaseTokenizer
+
+
+ @dataclass
+ class BatchResult:
+     num_samples: int
+     correct: int
+     ned: float
+     confidence: float
+     label_length: int
+     loss: Tensor
+     loss_numel: int
+
+
+ class BaseSystem(pl.LightningModule, ABC):
+
+     def __init__(self, tokenizer: BaseTokenizer, charset_test: str,
+                  batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.charset_adapter = CharsetAdapter(charset_test)
+         self.batch_size = batch_size
+         self.lr = lr
+         self.warmup_pct = warmup_pct
+         self.weight_decay = weight_decay
+
+     @abstractmethod
+     def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
+         """Inference
+
+         Args:
+             images: Batch of images. Shape: N, Ch, H, W
+             max_length: Max sequence length of the output. If None, will use default.
+
+         Returns:
+             logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
+         """Like forward(), but also computes the loss (calls forward() internally).
+
+         Args:
+             images: Batch of images. Shape: N, Ch, H, W
+             labels: Text labels of the images
+
+         Returns:
+             logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
+             loss: mean loss for the batch
+             loss_numel: number of elements the loss was calculated from
+         """
+         raise NotImplementedError
+
+     def configure_optimizers(self):
+         agb = self.trainer.accumulate_grad_batches
+         # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP.
+         lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.
+         lr = lr_scale * self.lr
+         optim = create_optimizer_v2(self, 'adamw', lr, self.weight_decay)
+         sched = OneCycleLR(optim, lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct,
+                            cycle_momentum=False)
+         return {'optimizer': optim, 'lr_scheduler': {'scheduler': sched, 'interval': 'step'}}
+
+     def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
+         optimizer.zero_grad(set_to_none=True)
+
+     def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]:
+         images, labels = batch
+
+         correct = 0
+         total = 0
+         ned = 0
+         confidence = 0
+         label_length = 0
+         if validation:
+             logits, loss, loss_numel = self.forward_logits_loss(images, labels)
+         else:
+             # At test-time, we shouldn't specify a max_label_length because the test-time charset used
+             # might be different from the train-time charset. max_label_length in eval_logits_loss() is computed
+             # based on the transformed label, which could be wrong if the actual gt label contains characters existing
+             # in the train-time charset but not in the test-time charset. For example, "aishahaleyes.blogspot.com"
+             # is exactly 25 characters, but if processed by CharsetAdapter for the 36-char set, it becomes 23 characters
+             # long only, which sets max_label_length = 23. This will cause the model prediction to be truncated.
+             logits = self.forward(images)
+             loss = loss_numel = None  # Only used for validation; not needed at test-time.
+
+         probs = logits.softmax(-1)
+         preds, probs = self.tokenizer.decode(probs)
+         for pred, prob, gt in zip(preds, probs, labels):
+             confidence += prob.prod().item()
+             pred = self.charset_adapter(pred)
+             # Follow ICDAR 2019 definition of N.E.D.
+             ned += edit_distance(pred, gt) / max(len(pred), len(gt))
+             if pred == gt:
+                 correct += 1
+             total += 1
+             label_length += len(pred)
+         return dict(output=BatchResult(total, correct, ned, confidence, label_length, loss, loss_numel))
+
+     @staticmethod
+     def _aggregate_results(outputs: EPOCH_OUTPUT) -> Tuple[float, float, float]:
+         if not outputs:
+             return 0., 0., 0.
+         total_loss = 0
+         total_loss_numel = 0
+         total_n_correct = 0
+         total_norm_ED = 0
+         total_size = 0
+         for result in outputs:
+             result = result['output']
+             total_loss += result.loss_numel * result.loss
+             total_loss_numel += result.loss_numel
+             total_n_correct += result.correct
+             total_norm_ED += result.ned
+             total_size += result.num_samples
+         acc = total_n_correct / total_size
+         ned = (1 - total_norm_ED / total_size)
+         loss = total_loss / total_loss_numel
+         return acc, ned, loss
+
+     def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
+         return self._eval_step(batch, True)
+
+     def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
+         acc, ned, loss = self._aggregate_results(outputs)
+         self.log('val_accuracy', 100 * acc, sync_dist=True)
+         self.log('val_NED', 100 * ned, sync_dist=True)
+         self.log('val_loss', loss, sync_dist=True)
+         self.log('hp_metric', acc, sync_dist=True)
+
+     def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
+         return self._eval_step(batch, False)
+
+
+ class CrossEntropySystem(BaseSystem):
+
+     def __init__(self, charset_train: str, charset_test: str,
+                  batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
+         tokenizer = Tokenizer(charset_train)
+         super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
+         self.bos_id = tokenizer.bos_id
+         self.eos_id = tokenizer.eos_id
+         self.pad_id = tokenizer.pad_id
+
+     def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
+         targets = self.tokenizer.encode(labels, self.device)
+         targets = targets[:, 1:]  # Discard <bos>
+         max_len = targets.shape[1] - 1  # exclude <eos> from count
+         logits = self.forward(images, max_len)
+         loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id)
+         loss_numel = (targets != self.pad_id).sum()
+         return logits, loss, loss_numel
+
+
+ class CTCSystem(BaseSystem):
+
+     def __init__(self, charset_train: str, charset_test: str,
+                  batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
+         tokenizer = CTCTokenizer(charset_train)
+         super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
+         self.blank_id = tokenizer.blank_id
+
+     def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
+         targets = self.tokenizer.encode(labels, self.device)
+         logits = self.forward(images)
+         log_probs = logits.log_softmax(-1).transpose(0, 1)  # swap batch and seq. dims
+         T, N, _ = log_probs.shape
+         input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long, device=self.device)
+         target_lengths = torch.as_tensor(list(map(len, labels)), dtype=torch.long, device=self.device)
+         loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True)
+         return logits, loss, N
strhub/models/.ipynb_checkpoints/modules-checkpoint.py ADDED
@@ -0,0 +1,20 @@
+ r"""Shared modules used by CRNN and TRBA"""
+ from torch import nn
+
+
+ class BidirectionalLSTM(nn.Module):
+     """Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py"""
+
+     def __init__(self, input_size, hidden_size, output_size):
+         super().__init__()
+         self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
+         self.linear = nn.Linear(hidden_size * 2, output_size)
+
+     def forward(self, input):
+         """
+         input : visual feature [batch_size x T x input_size], T = num_steps.
+         output : contextual feature [batch_size x T x output_size]
+         """
+         recurrent, _ = self.rnn(input)  # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
+         output = self.linear(recurrent)  # batch_size x T x output_size
+         return output
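A quick shape check for the module above (illustrative only; the same class also ships as strhub/models/modules.py in this commit):

    import torch
    from strhub.models.modules import BidirectionalLSTM

    rnn = BidirectionalLSTM(input_size=512, hidden_size=256, output_size=256)
    feats = torch.randn(4, 26, 512)   # batch_size x T x input_size
    print(rnn(feats).shape)           # torch.Size([4, 26, 256])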
strhub/models/.ipynb_checkpoints/utils-checkpoint.py ADDED
@@ -0,0 +1,123 @@
+ from pathlib import PurePath
+ from typing import Sequence
+
+ import torch
+ from torch import nn
+
+ import yaml
+
+
+ class InvalidModelError(RuntimeError):
+     """Exception raised for any model-related error (creation, loading)"""
+
+
+ _WEIGHTS_URL = {
+     'parseq-tiny': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq_tiny-e7a21b54.pt',
+     'parseq': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq-bb5792a6.pt',
+     'abinet': 'https://github.com/baudm/parseq/releases/download/v1.0.0/abinet-1d1e373e.pt',
+     'trba': 'https://github.com/baudm/parseq/releases/download/v1.0.0/trba-cfaed284.pt',
+     'vitstr': 'https://github.com/baudm/parseq/releases/download/v1.0.0/vitstr-26d0fcf4.pt',
+     'crnn': 'https://github.com/baudm/parseq/releases/download/v1.0.0/crnn-679d0e31.pt',
+ }
+
+
+ def _get_config(experiment: str, **kwargs):
+     """Emulates hydra config resolution"""
+     root = PurePath(__file__).parents[2]
+     with open(root / 'configs/main.yaml', 'r') as f:
+         config = yaml.load(f, yaml.Loader)['model']
+     with open(root / f'configs/charset/94_full.yaml', 'r') as f:
+         config.update(yaml.load(f, yaml.Loader)['model'])
+     with open(root / f'configs/experiment/{experiment}.yaml', 'r') as f:
+         exp = yaml.load(f, yaml.Loader)
+     # Apply base model config
+     model = exp['defaults'][0]['override /model']
+     with open(root / f'configs/model/{model}.yaml', 'r') as f:
+         config.update(yaml.load(f, yaml.Loader))
+     # Apply experiment config
+     if 'model' in exp:
+         config.update(exp['model'])
+     config.update(kwargs)
+     # Workaround for now: manually cast the lr to the correct type.
+     config['lr'] = float(config['lr'])
+     return config
+
+
+ def _get_model_class(key):
+     if 'abinet' in key:
+         from .abinet.system import ABINet as ModelClass
+     elif 'crnn' in key:
+         from .crnn.system import CRNN as ModelClass
+     elif 'parseq' in key:
+         from .parseq.system import PARSeq as ModelClass
+     elif 'trba' in key:
+         from .trba.system import TRBA as ModelClass
+     elif 'trbc' in key:
+         from .trba.system import TRBC as ModelClass
+     elif 'vitstr' in key:
+         from .vitstr.system import ViTSTR as ModelClass
+     else:
+         raise InvalidModelError("Unable to find model class for '{}'".format(key))
+     return ModelClass
+
+
+ def get_pretrained_weights(experiment):
+     try:
+         url = _WEIGHTS_URL[experiment]
+     except KeyError:
+         raise InvalidModelError("No pretrained weights found for '{}'".format(experiment)) from None
+     return torch.hub.load_state_dict_from_url(url=url, map_location='cpu', check_hash=True)
+
+
+ def create_model(experiment: str, pretrained: bool = False, **kwargs):
+     try:
+         config = _get_config(experiment, **kwargs)
+     except FileNotFoundError:
+         raise InvalidModelError("No configuration found for '{}'".format(experiment)) from None
+     ModelClass = _get_model_class(experiment)
+     model = ModelClass(**config)
+     if pretrained:
+         model.load_state_dict(get_pretrained_weights(experiment))
+     return model
+
+
+ def load_from_checkpoint(checkpoint_path: str, **kwargs):
+     if checkpoint_path.startswith('pretrained='):
+         model_id = checkpoint_path.split('=', maxsplit=1)[1]
+         model = create_model(model_id, True, **kwargs)
+     else:
+         ModelClass = _get_model_class(checkpoint_path)
+         model = ModelClass.load_from_checkpoint(checkpoint_path, **kwargs)
+     return model
+
+
+ def parse_model_args(args):
+     kwargs = {}
+     arg_types = {t.__name__: t for t in [int, float, str]}
+     arg_types['bool'] = lambda v: v.lower() == 'true'  # special handling for bool
+     for arg in args:
+         name, value = arg.split('=', maxsplit=1)
+         name, arg_type = name.split(':', maxsplit=1)
+         kwargs[name] = arg_types[arg_type](value)
+     return kwargs
+
+
+ def init_weights(module: nn.Module, name: str = '', exclude: Sequence[str] = ()):
+     """Initialize the weights using the typical initialization schemes used in SOTA models."""
+     if any(map(name.startswith, exclude)):
+         return
+     if isinstance(module, nn.Linear):
+         nn.init.trunc_normal_(module.weight, std=.02)
+         if module.bias is not None:
+             nn.init.zeros_(module.bias)
+     elif isinstance(module, nn.Embedding):
+         nn.init.trunc_normal_(module.weight, std=.02)
+         if module.padding_idx is not None:
+             module.weight.data[module.padding_idx].zero_()
+     elif isinstance(module, nn.Conv2d):
+         nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
+         if module.bias is not None:
+             nn.init.zeros_(module.bias)
+     elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
+         nn.init.ones_(module.weight)
+         nn.init.zeros_(module.bias)
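load_from_checkpoint() supports two paths, both visible above; a sketch of each (the second path additionally needs the configs/ tree from the upstream parseq repository, which is not part of this commit):

    from strhub.models.utils import load_from_checkpoint

    # 1. A local Lightning checkpoint, as model.py does:
    model = load_from_checkpoint("weights/parseq/last.ckpt").eval()

    # 2. Published weights resolved through _WEIGHTS_URL (requires the upstream configs/):
    pretrained = load_from_checkpoint("pretrained=parseq").eval()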
strhub/models/__init__.py ADDED
File without changes
strhub/models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (141 Bytes).
 
strhub/models/__pycache__/base.cpython-37.pyc ADDED
Binary file (7.54 kB).
 
strhub/models/__pycache__/utils.cpython-37.pyc ADDED
Binary file (4.69 kB).
 
strhub/models/abinet/LICENSE ADDED
@@ -0,0 +1,25 @@
+ ABINet for non-commercial purposes
+
+ Copyright (c) 2021, USTC
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
strhub/models/abinet/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ r"""
2
+ Fang, Shancheng, Hongtao Xie, Yuxin Wang, Zhendong Mao, and Yongdong Zhang.
3
+ "Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition."
4
+ In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 7098-7107, 2021.
5
+
6
+ https://arxiv.org/abs/2103.06495
7
+
8
+ All source files, except `system.py`, are based on the implementation listed below,
9
+ and hence are released under the license of the original.
10
+
11
+ Source: https://github.com/FangShancheng/ABINet
12
+ License: 2-clause BSD License (see included LICENSE file)
13
+ """
strhub/models/abinet/attention.py ADDED
@@ -0,0 +1,100 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .transformer import PositionalEncoding
5
+
6
+
7
+ class Attention(nn.Module):
8
+ def __init__(self, in_channels=512, max_length=25, n_feature=256):
9
+ super().__init__()
10
+ self.max_length = max_length
11
+
12
+ self.f0_embedding = nn.Embedding(max_length, in_channels)
13
+ self.w0 = nn.Linear(max_length, n_feature)
14
+ self.wv = nn.Linear(in_channels, in_channels)
15
+ self.we = nn.Linear(in_channels, max_length)
16
+
17
+ self.active = nn.Tanh()
18
+ self.softmax = nn.Softmax(dim=2)
19
+
20
+ def forward(self, enc_output):
21
+ enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2)
22
+ reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device)
23
+ reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S)
24
+ reading_order_embed = self.f0_embedding(reading_order) # b,25,512
25
+
26
+ t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256
27
+ t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512
28
+
29
+ attn = self.we(t) # b,256,25
30
+ attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256
31
+ g_output = torch.bmm(attn, enc_output) # b,25,512
32
+ return g_output, attn.view(*attn.shape[:2], 8, 32)
33
+
34
+
35
+ def encoder_layer(in_c, out_c, k=3, s=2, p=1):
36
+ return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p),
37
+ nn.BatchNorm2d(out_c),
38
+ nn.ReLU(True))
39
+
40
+
41
+ def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None):
42
+ align_corners = None if mode == 'nearest' else True
43
+ return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor,
44
+ mode=mode, align_corners=align_corners),
45
+ nn.Conv2d(in_c, out_c, k, s, p),
46
+ nn.BatchNorm2d(out_c),
47
+ nn.ReLU(True))
48
+
49
+
50
+ class PositionAttention(nn.Module):
51
+ def __init__(self, max_length, in_channels=512, num_channels=64,
52
+ h=8, w=32, mode='nearest', **kwargs):
53
+ super().__init__()
54
+ self.max_length = max_length
55
+ self.k_encoder = nn.Sequential(
56
+ encoder_layer(in_channels, num_channels, s=(1, 2)),
57
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
58
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
59
+ encoder_layer(num_channels, num_channels, s=(2, 2))
60
+ )
61
+ self.k_decoder = nn.Sequential(
62
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
63
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
64
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
65
+ decoder_layer(num_channels, in_channels, size=(h, w), mode=mode)
66
+ )
67
+
68
+ self.pos_encoder = PositionalEncoding(in_channels, dropout=0., max_len=max_length)
69
+ self.project = nn.Linear(in_channels, in_channels)
70
+
71
+ def forward(self, x):
72
+ N, E, H, W = x.size()
73
+ k, v = x, x # (N, E, H, W)
74
+
75
+ # calculate key vector
76
+ features = []
77
+ for i in range(0, len(self.k_encoder)):
78
+ k = self.k_encoder[i](k)
79
+ features.append(k)
80
+ for i in range(0, len(self.k_decoder) - 1):
81
+ k = self.k_decoder[i](k)
82
+ k = k + features[len(self.k_decoder) - 2 - i]
83
+ k = self.k_decoder[-1](k)
84
+
85
+ # calculate query vector
86
+ # TODO q=f(q,k)
87
+ zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E)
88
+ q = self.pos_encoder(zeros) # (T, N, E)
89
+ q = q.permute(1, 0, 2) # (N, T, E)
90
+ q = self.project(q) # (N, T, E)
91
+
92
+ # calculate attention
93
+ attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W))
94
+ attn_scores = attn_scores / (E ** 0.5)
95
+ attn_scores = torch.softmax(attn_scores, dim=-1)
96
+
97
+ v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E)
98
+ attn_vecs = torch.bmm(attn_scores, v) # (N, T, E)
99
+
100
+ return attn_vecs, attn_scores.view(N, -1, H, W)
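Shape sanity check for PositionAttention (editor's illustrative sketch, not part of the commit; assumes the default in_channels=512 and an 8x32 backbone feature map):

import torch
from strhub.models.abinet.attention import PositionAttention

attn = PositionAttention(max_length=26)   # e.g. 25 characters + stop token
feats = torch.rand(2, 512, 8, 32)         # (N, E, H, W) from the backbone
vecs, scores = attn(feats)
print(vecs.shape, scores.shape)           # (2, 26, 512) and (2, 26, 8, 32)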
strhub/models/abinet/backbone.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch.nn as nn
2
+ from torch.nn import TransformerEncoderLayer, TransformerEncoder
3
+
4
+ from .resnet import resnet45
5
+ from .transformer import PositionalEncoding
6
+
7
+
8
+ class ResTranformer(nn.Module):
9
+ def __init__(self, d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', backbone_ln=2):
10
+ super().__init__()
11
+ self.resnet = resnet45()
12
+ self.pos_encoder = PositionalEncoding(d_model, max_len=8 * 32)
13
+ encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead,
14
+ dim_feedforward=d_inner, dropout=dropout, activation=activation)
15
+ self.transformer = TransformerEncoder(encoder_layer, backbone_ln)
16
+
17
+ def forward(self, images):
18
+ feature = self.resnet(images)
19
+ n, c, h, w = feature.shape
20
+ feature = feature.view(n, c, -1).permute(2, 0, 1)
21
+ feature = self.pos_encoder(feature)
22
+ feature = self.transformer(feature)
23
+ feature = feature.permute(1, 2, 0).view(n, c, h, w)
24
+ return feature
strhub/models/abinet/model.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class Model(nn.Module):
6
+
7
+ def __init__(self, dataset_max_length: int, null_label: int):
8
+ super().__init__()
9
+ self.max_length = dataset_max_length + 1 # additional stop token
10
+ self.null_label = null_label
11
+
12
+ def _get_length(self, logit, dim=-1):
13
+ """ Greed decoder to obtain length from logit"""
14
+ out = (logit.argmax(dim=-1) == self.null_label)
15
+ abn = out.any(dim)
16
+ out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
17
+ out = out + 1 # additional end token
18
+ out = torch.where(abn, out, out.new_tensor(logit.shape[1], device=out.device))
19
+ return out
20
+
21
+ @staticmethod
22
+ def _get_padding_mask(length, max_length):
23
+ length = length.unsqueeze(-1)
24
+ grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
25
+ return grid >= length
26
+
27
+ @staticmethod
28
+ def _get_location_mask(sz, device=None):
29
+ mask = torch.eye(sz, device=device)
30
+ mask = mask.float().masked_fill(mask == 1, float('-inf'))
31
+ return mask
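Worked example of the two mask helpers above (editor's illustrative sketch, not part of the commit; assumes the module import path strhub.models.abinet.model):

import torch
from strhub.models.abinet.model import Model

lengths = torch.tensor([2, 4])
print(Model._get_padding_mask(lengths, max_length=5))
# tensor([[False, False,  True,  True,  True],
#         [False, False, False, False,  True]])

print(Model._get_location_mask(3))
# 3x3 matrix with -inf on the diagonal, so position i cannot attend to token i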
strhub/models/abinet/model_abinet_iter.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from .model_alignment import BaseAlignment
5
+ from .model_language import BCNLanguage
6
+ from .model_vision import BaseVision
7
+
8
+
9
+ class ABINetIterModel(nn.Module):
10
+ def __init__(self, dataset_max_length, null_label, num_classes, iter_size=1,
11
+ d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
12
+ v_loss_weight=1., v_attention='position', v_attention_mode='nearest',
13
+ v_backbone='transformer', v_num_layers=2,
14
+ l_loss_weight=1., l_num_layers=4, l_detach=True, l_use_self_attn=False,
15
+ a_loss_weight=1.):
16
+ super().__init__()
17
+ self.iter_size = iter_size
18
+ self.vision = BaseVision(dataset_max_length, null_label, num_classes, v_attention, v_attention_mode,
19
+ v_loss_weight, d_model, nhead, d_inner, dropout, activation, v_backbone, v_num_layers)
20
+ self.language = BCNLanguage(dataset_max_length, null_label, num_classes, d_model, nhead, d_inner, dropout,
21
+ activation, l_num_layers, l_detach, l_use_self_attn, l_loss_weight)
22
+ self.alignment = BaseAlignment(dataset_max_length, null_label, num_classes, d_model, a_loss_weight)
23
+
24
+ def forward(self, images):
25
+ v_res = self.vision(images)
26
+ a_res = v_res
27
+ all_l_res, all_a_res = [], []
28
+ for _ in range(self.iter_size):
29
+ tokens = torch.softmax(a_res['logits'], dim=-1)
30
+ lengths = a_res['pt_lengths']
31
+ lengths.clamp_(2, self.language.max_length) # TODO: move to language model
32
+ l_res = self.language(tokens, lengths)
33
+ all_l_res.append(l_res)
34
+ a_res = self.alignment(l_res['feature'], v_res['feature'])
35
+ all_a_res.append(a_res)
36
+ if self.training:
37
+ return all_a_res, all_l_res, v_res
38
+ else:
39
+ return a_res, all_l_res[-1], v_res
strhub/models/abinet/model_alignment.py ADDED
@@ -0,0 +1,28 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .model import Model
5
+
6
+
7
+ class BaseAlignment(Model):
8
+ def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, loss_weight=1.0):
9
+ super().__init__(dataset_max_length, null_label)
10
+ self.loss_weight = loss_weight
11
+ self.w_att = nn.Linear(2 * d_model, d_model)
12
+ self.cls = nn.Linear(d_model, num_classes)
13
+
14
+ def forward(self, l_feature, v_feature):
15
+ """
16
+ Args:
17
+ l_feature: (N, T, E) where T is length, N is batch size and d is dim of model
18
+ v_feature: (N, T, E) shape the same as l_feature
19
+ """
20
+ f = torch.cat((l_feature, v_feature), dim=2)
21
+ f_att = torch.sigmoid(self.w_att(f))
22
+ output = f_att * v_feature + (1 - f_att) * l_feature
23
+
24
+ logits = self.cls(output) # (N, T, C)
25
+ pt_lengths = self._get_length(logits)
26
+
27
+ return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight': self.loss_weight,
28
+ 'name': 'alignment'}
strhub/models/abinet/model_language.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch.nn as nn
2
+ from torch.nn import TransformerDecoder
3
+
4
+ from .model import Model
5
+ from .transformer import PositionalEncoding, TransformerDecoderLayer
6
+
7
+
8
+ class BCNLanguage(Model):
9
+ def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, nhead=8, d_inner=2048, dropout=0.1,
10
+ activation='relu', num_layers=4, detach=True, use_self_attn=False, loss_weight=1.0,
11
+ global_debug=False):
12
+ super().__init__(dataset_max_length, null_label)
13
+ self.detach = detach
14
+ self.loss_weight = loss_weight
15
+ self.proj = nn.Linear(num_classes, d_model, False)
16
+ self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length)
17
+ self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length)
18
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout,
19
+ activation, self_attn=use_self_attn, debug=global_debug)
20
+ self.model = TransformerDecoder(decoder_layer, num_layers)
21
+ self.cls = nn.Linear(d_model, num_classes)
22
+
23
+ def forward(self, tokens, lengths):
24
+ """
25
+ Args:
26
+ tokens: (N, T, C) where T is length, N is batch size and C is the number of classes
27
+ lengths: (N,)
28
+ """
29
+ if self.detach:
30
+ tokens = tokens.detach()
31
+ embed = self.proj(tokens) # (N, T, E)
32
+ embed = embed.permute(1, 0, 2) # (T, N, E)
33
+ embed = self.token_encoder(embed) # (T, N, E)
34
+ padding_mask = self._get_padding_mask(lengths, self.max_length)
35
+
36
+ zeros = embed.new_zeros(*embed.shape)
37
+ query = self.pos_encoder(zeros)
38
+ location_mask = self._get_location_mask(self.max_length, tokens.device)
39
+ output = self.model(query, embed,
40
+ tgt_key_padding_mask=padding_mask,
41
+ memory_mask=location_mask,
42
+ memory_key_padding_mask=padding_mask) # (T, N, E)
43
+ output = output.permute(1, 0, 2) # (N, T, E)
44
+
45
+ logits = self.cls(output) # (N, T, C)
46
+ pt_lengths = self._get_length(logits)
47
+
48
+ res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths,
49
+ 'loss_weight': self.loss_weight, 'name': 'language'}
50
+ return res
strhub/models/abinet/model_vision.py ADDED
@@ -0,0 +1,45 @@
1
+ from torch import nn
2
+
3
+ from .attention import PositionAttention, Attention
4
+ from .backbone import ResTranformer
5
+ from .model import Model
6
+ from .resnet import resnet45
7
+
8
+
9
+ class BaseVision(Model):
10
+ def __init__(self, dataset_max_length, null_label, num_classes,
11
+ attention='position', attention_mode='nearest', loss_weight=1.0,
12
+ d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
13
+ backbone='transformer', backbone_ln=2):
14
+ super().__init__(dataset_max_length, null_label)
15
+ self.loss_weight = loss_weight
16
+ self.out_channels = d_model
17
+
18
+ if backbone == 'transformer':
19
+ self.backbone = ResTranformer(d_model, nhead, d_inner, dropout, activation, backbone_ln)
20
+ else:
21
+ self.backbone = resnet45()
22
+
23
+ if attention == 'position':
24
+ self.attention = PositionAttention(
25
+ max_length=self.max_length,
26
+ mode=attention_mode
27
+ )
28
+ elif attention == 'attention':
29
+ self.attention = Attention(
30
+ max_length=self.max_length,
31
+ n_feature=8 * 32,
32
+ )
33
+ else:
34
+ raise ValueError(f'invalid attention: {attention}')
35
+
36
+ self.cls = nn.Linear(self.out_channels, num_classes)
37
+
38
+ def forward(self, images):
39
+ features = self.backbone(images) # (N, E, H, W)
40
+ attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W)
41
+ logits = self.cls(attn_vecs) # (N, T, C)
42
+ pt_lengths = self._get_length(logits)
43
+
44
+ return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
45
+ 'attn_scores': attn_scores, 'loss_weight': self.loss_weight, 'name': 'vision'}
strhub/models/abinet/resnet.py ADDED
@@ -0,0 +1,72 @@
1
+ import math
2
+ from typing import Optional, Callable
3
+
4
+ import torch.nn as nn
5
+ from torchvision.models import resnet
6
+
7
+
8
+ class BasicBlock(resnet.BasicBlock):
9
+
10
+ def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None,
11
+ groups: int = 1, base_width: int = 64, dilation: int = 1,
12
+ norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
13
+ super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)
14
+ self.conv1 = resnet.conv1x1(inplanes, planes)
15
+ self.conv2 = resnet.conv3x3(planes, planes, stride)
16
+
17
+
18
+ class ResNet(nn.Module):
19
+
20
+ def __init__(self, block, layers):
21
+ super().__init__()
22
+ self.inplanes = 32
23
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1,
24
+ bias=False)
25
+ self.bn1 = nn.BatchNorm2d(32)
26
+ self.relu = nn.ReLU(inplace=True)
27
+
28
+ self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
29
+ self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
30
+ self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
31
+ self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
32
+ self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
33
+
34
+ for m in self.modules():
35
+ if isinstance(m, nn.Conv2d):
36
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
37
+ m.weight.data.normal_(0, math.sqrt(2. / n))
38
+ elif isinstance(m, nn.BatchNorm2d):
39
+ m.weight.data.fill_(1)
40
+ m.bias.data.zero_()
41
+
42
+ def _make_layer(self, block, planes, blocks, stride=1):
43
+ downsample = None
44
+ if stride != 1 or self.inplanes != planes * block.expansion:
45
+ downsample = nn.Sequential(
46
+ nn.Conv2d(self.inplanes, planes * block.expansion,
47
+ kernel_size=1, stride=stride, bias=False),
48
+ nn.BatchNorm2d(planes * block.expansion),
49
+ )
50
+
51
+ layers = []
52
+ layers.append(block(self.inplanes, planes, stride, downsample))
53
+ self.inplanes = planes * block.expansion
54
+ for i in range(1, blocks):
55
+ layers.append(block(self.inplanes, planes))
56
+
57
+ return nn.Sequential(*layers)
58
+
59
+ def forward(self, x):
60
+ x = self.conv1(x)
61
+ x = self.bn1(x)
62
+ x = self.relu(x)
63
+ x = self.layer1(x)
64
+ x = self.layer2(x)
65
+ x = self.layer3(x)
66
+ x = self.layer4(x)
67
+ x = self.layer5(x)
68
+ return x
69
+
70
+
71
+ def resnet45():
72
+ return ResNet(BasicBlock, [3, 4, 6, 6, 3])
strhub/models/abinet/system.py ADDED
@@ -0,0 +1,172 @@
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import math
18
+ from typing import Any, Tuple, List, Optional
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from torch import Tensor, nn
23
+ from torch.optim import AdamW
24
+ from torch.optim.lr_scheduler import OneCycleLR
25
+
26
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
27
+ from timm.optim.optim_factory import param_groups_weight_decay
28
+
29
+ from strhub.models.base import CrossEntropySystem
30
+ from strhub.models.utils import init_weights
31
+ from .model_abinet_iter import ABINetIterModel as Model
32
+
33
+ log = logging.getLogger(__name__)
34
+
35
+
36
+ class ABINet(CrossEntropySystem):
37
+
38
+ def __init__(self, charset_train: str, charset_test: str, max_label_length: int,
39
+ batch_size: int, lr: float, warmup_pct: float, weight_decay: float,
40
+ iter_size: int, d_model: int, nhead: int, d_inner: int, dropout: float, activation: str,
41
+ v_loss_weight: float, v_attention: str, v_attention_mode: str, v_backbone: str, v_num_layers: int,
42
+ l_loss_weight: float, l_num_layers: int, l_detach: bool, l_use_self_attn: bool,
43
+ l_lr: float, a_loss_weight: float, lm_only: bool = False, **kwargs) -> None:
44
+ super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
45
+ self.scheduler = None
46
+ self.save_hyperparameters()
47
+ self.max_label_length = max_label_length
48
+ self.num_classes = len(self.tokenizer) - 2 # We don't predict <bos> nor <pad>
49
+ self.model = Model(max_label_length, self.eos_id, self.num_classes, iter_size, d_model, nhead, d_inner,
50
+ dropout, activation, v_loss_weight, v_attention, v_attention_mode, v_backbone, v_num_layers,
51
+ l_loss_weight, l_num_layers, l_detach, l_use_self_attn, a_loss_weight)
52
+ self.model.apply(init_weights)
53
+ # FIXME: doesn't support resumption from checkpoint yet
54
+ self._reset_alignment = True
55
+ self._reset_optimizers = True
56
+ self.l_lr = l_lr
57
+ self.lm_only = lm_only
58
+ # Train LM only. Freeze other submodels.
59
+ if lm_only:
60
+ self.l_lr = lr # for tuning
61
+ self.model.vision.requires_grad_(False)
62
+ self.model.alignment.requires_grad_(False)
63
+
64
+ @property
65
+ def _pretraining(self):
66
+ # In the original work, VM was pretrained for 8 epochs while full model was trained for an additional 10 epochs.
67
+ total_steps = self.trainer.estimated_stepping_batches * self.trainer.accumulate_grad_batches
68
+ return self.global_step < (8 / (8 + 10)) * total_steps
69
+
70
+ @torch.jit.ignore
71
+ def no_weight_decay(self):
72
+ return {'model.language.proj.weight'}
73
+
74
+ def _add_weight_decay(self, model: nn.Module, skip_list=()):
75
+ if self.weight_decay:
76
+ return param_groups_weight_decay(model, self.weight_decay, skip_list)
77
+ else:
78
+ return [{'params': model.parameters()}]
79
+
80
+ def configure_optimizers(self):
81
+ agb = self.trainer.accumulate_grad_batches
82
+ # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP.
83
+ lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.
84
+ lr = lr_scale * self.lr
85
+ l_lr = lr_scale * self.l_lr
86
+ params = []
87
+ params.extend(self._add_weight_decay(self.model.vision))
88
+ params.extend(self._add_weight_decay(self.model.alignment))
89
+ # We use a different learning rate for the LM.
90
+ for p in self._add_weight_decay(self.model.language, ('proj.weight',)):
91
+ p['lr'] = l_lr
92
+ params.append(p)
93
+ max_lr = [p.get('lr', lr) for p in params]
94
+ optim = AdamW(params, lr)
95
+ self.scheduler = OneCycleLR(optim, max_lr, self.trainer.estimated_stepping_batches,
96
+ pct_start=self.warmup_pct, cycle_momentum=False)
97
+ return {'optimizer': optim, 'lr_scheduler': {'scheduler': self.scheduler, 'interval': 'step'}}
98
+
99
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
100
+ max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
101
+ logits = self.model.forward(images)[0]['logits']
102
+ return logits[:, :max_length + 1] # truncate
103
+
104
+ def calc_loss(self, targets, *res_lists) -> Tensor:
105
+ total_loss = 0
106
+ for res_list in res_lists:
107
+ loss = 0
108
+ if isinstance(res_list, dict):
109
+ res_list = [res_list]
110
+ for res in res_list:
111
+ logits = res['logits'].flatten(end_dim=1)
112
+ loss += F.cross_entropy(logits, targets.flatten(), ignore_index=self.pad_id)
113
+ loss /= len(res_list)
114
+ self.log('loss_' + res_list[0]['name'], loss)
115
+ total_loss += res_list[0]['loss_weight'] * loss
116
+ return total_loss
117
+
118
+ def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
119
+ if not self._pretraining and self._reset_optimizers:
120
+ log.info('Pretraining ends. Updating base LRs.')
121
+ self._reset_optimizers = False
122
+ # Make base_lr the same for all groups
123
+ base_lr = self.scheduler.base_lrs[0] # base_lr of group 0 - VM
124
+ self.scheduler.base_lrs = [base_lr] * len(self.scheduler.base_lrs)
125
+
126
+ def _prepare_inputs_and_targets(self, labels):
127
+ # Use dummy label to ensure sequence length is constant.
128
+ dummy = ['0' * self.max_label_length]
129
+ targets = self.tokenizer.encode(dummy + list(labels), self.device)[1:]
130
+ targets = targets[:, 1:] # remove <bos>. Unused here.
131
+ # Inputs are padded with eos_id
132
+ inputs = torch.where(targets == self.pad_id, self.eos_id, targets)
133
+ inputs = F.one_hot(inputs, self.num_classes).float()
134
+ lengths = torch.as_tensor(list(map(len, labels)), device=self.device) + 1 # +1 for eos
135
+ return inputs, lengths, targets
136
+
137
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
138
+ images, labels = batch
139
+ inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
140
+ if self.lm_only:
141
+ l_res = self.model.language(inputs, lengths)
142
+ loss = self.calc_loss(targets, l_res)
143
+ # Pretrain submodels independently first
144
+ elif self._pretraining:
145
+ # Vision
146
+ v_res = self.model.vision(images)
147
+ # Language
148
+ l_res = self.model.language(inputs, lengths)
149
+ # We also train the alignment model to 'satisfy' DDP requirements (all parameters should be used).
150
+ # We'll reset its parameters prior to joint training.
151
+ a_res = self.model.alignment(l_res['feature'].detach(), v_res['feature'].detach())
152
+ loss = self.calc_loss(targets, v_res, l_res, a_res)
153
+ else:
154
+ # Reset alignment model's parameters once prior to full model training.
155
+ if self._reset_alignment:
156
+ log.info('Pretraining ends. Resetting alignment model.')
157
+ self._reset_alignment = False
158
+ self.model.alignment.apply(init_weights)
159
+ all_a_res, all_l_res, v_res = self.model.forward(images)
160
+ loss = self.calc_loss(targets, v_res, all_l_res, all_a_res)
161
+ self.log('loss', loss)
162
+ return loss
163
+
164
+ def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
165
+ if self.lm_only:
166
+ inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
167
+ l_res = self.model.language(inputs, lengths)
168
+ loss = self.calc_loss(targets, l_res)
169
+ loss_numel = (targets != self.pad_id).sum()
170
+ return l_res['logits'], loss, loss_numel
171
+ else:
172
+ return super().forward_logits_loss(images, labels)
strhub/models/abinet/transformer.py ADDED
@@ -0,0 +1,143 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from torch.nn.modules.transformer import _get_activation_fn
7
+
8
+
9
+ class TransformerDecoderLayer(nn.Module):
10
+ r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
11
+ This standard decoder layer is based on the paper "Attention Is All You Need".
12
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
13
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
14
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
15
+ in a different way during application.
16
+
17
+ Args:
18
+ d_model: the number of expected features in the input (required).
19
+ nhead: the number of heads in the multiheadattention models (required).
20
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
21
+ dropout: the dropout value (default=0.1).
22
+ activation: the activation function of intermediate layer, relu or gelu (default=relu).
23
+
24
+ Examples::
25
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
26
+ >>> memory = torch.rand(10, 32, 512)
27
+ >>> tgt = torch.rand(20, 32, 512)
28
+ >>> out = decoder_layer(tgt, memory)
29
+ """
30
+
31
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
32
+ activation="relu", self_attn=True, siamese=False, debug=False):
33
+ super().__init__()
34
+ self.has_self_attn, self.siamese = self_attn, siamese
35
+ self.debug = debug
36
+ if self.has_self_attn:
37
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
38
+ self.norm1 = nn.LayerNorm(d_model)
39
+ self.dropout1 = nn.Dropout(dropout)
40
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
41
+ # Implementation of Feedforward model
42
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
43
+ self.dropout = nn.Dropout(dropout)
44
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
45
+
46
+ self.norm2 = nn.LayerNorm(d_model)
47
+ self.norm3 = nn.LayerNorm(d_model)
48
+ self.dropout2 = nn.Dropout(dropout)
49
+ self.dropout3 = nn.Dropout(dropout)
50
+ if self.siamese:
51
+ self.multihead_attn2 = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
52
+
53
+ self.activation = _get_activation_fn(activation)
54
+
55
+ def __setstate__(self, state):
56
+ if 'activation' not in state:
57
+ state['activation'] = F.relu
58
+ super().__setstate__(state)
59
+
60
+ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
61
+ tgt_key_padding_mask=None, memory_key_padding_mask=None,
62
+ memory2=None, memory_mask2=None, memory_key_padding_mask2=None):
63
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
64
+ r"""Pass the inputs (and mask) through the decoder layer.
65
+
66
+ Args:
67
+ tgt: the sequence to the decoder layer (required).
68
+ memory: the sequence from the last layer of the encoder (required).
69
+ tgt_mask: the mask for the tgt sequence (optional).
70
+ memory_mask: the mask for the memory sequence (optional).
71
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
72
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
73
+
74
+ Shape:
75
+ see the docs in Transformer class.
76
+ """
77
+ if self.has_self_attn:
78
+ tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
79
+ key_padding_mask=tgt_key_padding_mask)
80
+ tgt = tgt + self.dropout1(tgt2)
81
+ tgt = self.norm1(tgt)
82
+ if self.debug: self.attn = attn
83
+ tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
84
+ key_padding_mask=memory_key_padding_mask)
85
+ if self.debug: self.attn2 = attn2
86
+
87
+ if self.siamese:
88
+ tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2,
89
+ key_padding_mask=memory_key_padding_mask2)
90
+ tgt = tgt + self.dropout2(tgt3)
91
+ if self.debug: self.attn3 = attn3
92
+
93
+ tgt = tgt + self.dropout2(tgt2)
94
+ tgt = self.norm2(tgt)
95
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
96
+ tgt = tgt + self.dropout3(tgt2)
97
+ tgt = self.norm3(tgt)
98
+
99
+ return tgt
100
+
101
+
102
+ class PositionalEncoding(nn.Module):
103
+ r"""Inject some information about the relative or absolute position of the tokens
104
+ in the sequence. The positional encodings have the same dimension as
105
+ the embeddings, so that the two can be summed. Here, we use sine and cosine
106
+ functions of different frequencies.
107
+ .. math::
108
+ \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
109
+ \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
110
+ \text{where pos is the word position and i is the embed idx}
111
+ Args:
112
+ d_model: the embed dim (required).
113
+ dropout: the dropout value (default=0.1).
114
+ max_len: the max. length of the incoming sequence (default=5000).
115
+ Examples:
116
+ >>> pos_encoder = PositionalEncoding(d_model)
117
+ """
118
+
119
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
120
+ super().__init__()
121
+ self.dropout = nn.Dropout(p=dropout)
122
+
123
+ pe = torch.zeros(max_len, d_model)
124
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
125
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
126
+ pe[:, 0::2] = torch.sin(position * div_term)
127
+ pe[:, 1::2] = torch.cos(position * div_term)
128
+ pe = pe.unsqueeze(0).transpose(0, 1)
129
+ self.register_buffer('pe', pe)
130
+
131
+ def forward(self, x):
132
+ r"""Inputs of forward function
133
+ Args:
134
+ x: the sequence fed to the positional encoder model (required).
135
+ Shape:
136
+ x: [sequence length, batch size, embed dim]
137
+ output: [sequence length, batch size, embed dim]
138
+ Examples:
139
+ >>> output = pos_encoder(x)
140
+ """
141
+
142
+ x = x + self.pe[:x.size(0), :]
143
+ return self.dropout(x)
strhub/models/base.py ADDED
@@ -0,0 +1,202 @@
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from abc import ABC, abstractmethod
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, List
20
+
21
+ import pytorch_lightning as pl
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from nltk import edit_distance
25
+ from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
26
+ from timm.optim import create_optimizer_v2
27
+ from torch import Tensor
28
+ from torch.optim import Optimizer
29
+ from torch.optim.lr_scheduler import OneCycleLR
30
+
31
+ from strhub.data.utils import CharsetAdapter, CTCTokenizer, Tokenizer, BaseTokenizer
32
+
33
+
34
+ @dataclass
35
+ class BatchResult:
36
+ num_samples: int
37
+ correct: int
38
+ ned: float
39
+ confidence: float
40
+ label_length: int
41
+ loss: Tensor
42
+ loss_numel: int
43
+
44
+
45
+ class BaseSystem(pl.LightningModule, ABC):
46
+
47
+ def __init__(self, tokenizer: BaseTokenizer, charset_test: str,
48
+ batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
49
+ super().__init__()
50
+ self.tokenizer = tokenizer
51
+ self.charset_adapter = CharsetAdapter(charset_test)
52
+ self.batch_size = batch_size
53
+ self.lr = lr
54
+ self.warmup_pct = warmup_pct
55
+ self.weight_decay = weight_decay
56
+
57
+ @abstractmethod
58
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
59
+ """Inference
60
+
61
+ Args:
62
+ images: Batch of images. Shape: N, Ch, H, W
63
+ max_length: Max sequence length of the output. If None, will use default.
64
+
65
+ Returns:
66
+ logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
67
+ """
68
+ raise NotImplementedError
69
+
70
+ @abstractmethod
71
+ def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
72
+ """Like forward(), but also computes the loss (calls forward() internally).
73
+
74
+ Args:
75
+ images: Batch of images. Shape: N, Ch, H, W
76
+ labels: Text labels of the images
77
+
78
+ Returns:
79
+ logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
80
+ loss: mean loss for the batch
81
+ loss_numel: number of elements the loss was calculated from
82
+ """
83
+ raise NotImplementedError
84
+
85
+ def configure_optimizers(self):
86
+ agb = self.trainer.accumulate_grad_batches
87
+ # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP.
88
+ lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.
89
+ lr = lr_scale * self.lr
90
+ optim = create_optimizer_v2(self, 'adamw', lr, self.weight_decay)
91
+ sched = OneCycleLR(optim, lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct,
92
+ cycle_momentum=False)
93
+ return {'optimizer': optim, 'lr_scheduler': {'scheduler': sched, 'interval': 'step'}}
94
+
95
+ def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
96
+ optimizer.zero_grad(set_to_none=True)
97
+
98
+ def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]:
99
+ images, labels = batch
100
+
101
+ correct = 0
102
+ total = 0
103
+ ned = 0
104
+ confidence = 0
105
+ label_length = 0
106
+ if validation:
107
+ logits, loss, loss_numel = self.forward_logits_loss(images, labels)
108
+ else:
109
+ # At test-time, we shouldn't specify a max_label_length because the test-time charset used
110
+ # might be different from the train-time charset. max_label_length in eval_logits_loss() is computed
111
+ # based on the transformed label, which could be wrong if the actual gt label contains characters existing
112
+ # in the train-time charset but not in the test-time charset. For example, "aishahaleyes.blogspot.com"
113
+ # is exactly 25 characters, but if processed by CharsetAdapter for the 36-char set, it becomes 23 characters
114
+ # long only, which sets max_label_length = 23. This will cause the model prediction to be truncated.
115
+ logits = self.forward(images)
116
+ loss = loss_numel = None # Only used for validation; not needed at test-time.
117
+
118
+ probs = logits.softmax(-1)
119
+ preds, probs = self.tokenizer.decode(probs)
120
+ for pred, prob, gt in zip(preds, probs, labels):
121
+ confidence += prob.prod().item()
122
+ pred = self.charset_adapter(pred)
123
+ # Follow ICDAR 2019 definition of N.E.D.
124
+ ned += edit_distance(pred, gt) / max(len(pred), len(gt))
125
+ if pred == gt:
126
+ correct += 1
127
+ total += 1
128
+ label_length += len(pred)
129
+ return dict(output=BatchResult(total, correct, ned, confidence, label_length, loss, loss_numel))
130
+
131
+ @staticmethod
132
+ def _aggregate_results(outputs: EPOCH_OUTPUT) -> Tuple[float, float, float]:
133
+ if not outputs:
134
+ return 0., 0., 0.
135
+ total_loss = 0
136
+ total_loss_numel = 0
137
+ total_n_correct = 0
138
+ total_norm_ED = 0
139
+ total_size = 0
140
+ for result in outputs:
141
+ result = result['output']
142
+ total_loss += result.loss_numel * result.loss
143
+ total_loss_numel += result.loss_numel
144
+ total_n_correct += result.correct
145
+ total_norm_ED += result.ned
146
+ total_size += result.num_samples
147
+ acc = total_n_correct / total_size
148
+ ned = (1 - total_norm_ED / total_size)
149
+ loss = total_loss / total_loss_numel
150
+ return acc, ned, loss
151
+
152
+ def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
153
+ return self._eval_step(batch, True)
154
+
155
+ def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
156
+ acc, ned, loss = self._aggregate_results(outputs)
157
+ self.log('val_accuracy', 100 * acc, sync_dist=True)
158
+ self.log('val_NED', 100 * ned, sync_dist=True)
159
+ self.log('val_loss', loss, sync_dist=True)
160
+ self.log('hp_metric', acc, sync_dist=True)
161
+
162
+ def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
163
+ return self._eval_step(batch, False)
164
+
165
+
166
+ class CrossEntropySystem(BaseSystem):
167
+
168
+ def __init__(self, charset_train: str, charset_test: str,
169
+ batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
170
+ tokenizer = Tokenizer(charset_train)
171
+ super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
172
+ self.bos_id = tokenizer.bos_id
173
+ self.eos_id = tokenizer.eos_id
174
+ self.pad_id = tokenizer.pad_id
175
+
176
+ def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
177
+ targets = self.tokenizer.encode(labels, self.device)
178
+ targets = targets[:, 1:] # Discard <bos>
179
+ max_len = targets.shape[1] - 1 # exclude <eos> from count
180
+ logits = self.forward(images, max_len)
181
+ loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id)
182
+ loss_numel = (targets != self.pad_id).sum()
183
+ return logits, loss, loss_numel
184
+
185
+
186
+ class CTCSystem(BaseSystem):
187
+
188
+ def __init__(self, charset_train: str, charset_test: str,
189
+ batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
190
+ tokenizer = CTCTokenizer(charset_train)
191
+ super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
192
+ self.blank_id = tokenizer.blank_id
193
+
194
+ def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
195
+ targets = self.tokenizer.encode(labels, self.device)
196
+ logits = self.forward(images)
197
+ log_probs = logits.log_softmax(-1).transpose(0, 1) # swap batch and seq. dims
198
+ T, N, _ = log_probs.shape
199
+ input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long, device=self.device)
200
+ target_lengths = torch.as_tensor(list(map(len, labels)), dtype=torch.long, device=self.device)
201
+ loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True)
202
+ return logits, loss, N
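To make the N.E.D. metric used in _eval_step concrete (editor's illustrative sketch, mirroring the ICDAR 2019 definition referenced in the comment above):

from nltk import edit_distance

pred, gt = 'he1lo', 'hello'
ned = edit_distance(pred, gt) / max(len(pred), len(gt))  # 1 edit / 5 chars = 0.2
# Accuracy counts this sample as wrong (pred != gt), while 1 - NED still credits it with 0.8.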
strhub/models/crnn/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) 2017 Jieru Mei <meijieru@gmail.com>
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
strhub/models/crnn/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ r"""
2
+ Shi, Baoguang, Xiang Bai, and Cong Yao.
3
+ "An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition."
4
+ IEEE Transactions on Pattern Analysis and Machine Intelligence 39, no. 11 (2016): 2298-2304.
5
+
6
+ https://arxiv.org/abs/1507.05717
7
+
8
+ All source files, except `system.py`, are based on the implementation listed below,
9
+ and hence are released under the license of the original.
10
+
11
+ Source: https://github.com/meijieru/crnn.pytorch
12
+ License: MIT License (see included LICENSE file)
13
+ """
strhub/models/crnn/model.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch.nn as nn
2
+
3
+ from strhub.models.modules import BidirectionalLSTM
4
+
5
+
6
+ class CRNN(nn.Module):
7
+
8
+ def __init__(self, img_h, nc, nclass, nh, leaky_relu=False):
9
+ super().__init__()
10
+ assert img_h % 16 == 0, 'img_h has to be a multiple of 16'
11
+
12
+ ks = [3, 3, 3, 3, 3, 3, 2]
13
+ ps = [1, 1, 1, 1, 1, 1, 0]
14
+ ss = [1, 1, 1, 1, 1, 1, 1]
15
+ nm = [64, 128, 256, 256, 512, 512, 512]
16
+
17
+ cnn = nn.Sequential()
18
+
19
+ def convRelu(i, batchNormalization=False):
20
+ nIn = nc if i == 0 else nm[i - 1]
21
+ nOut = nm[i]
22
+ cnn.add_module('conv{0}'.format(i),
23
+ nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i], bias=not batchNormalization))
24
+ if batchNormalization:
25
+ cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
26
+ if leaky_relu:
27
+ cnn.add_module('relu{0}'.format(i),
28
+ nn.LeakyReLU(0.2, inplace=True))
29
+ else:
30
+ cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
31
+
32
+ convRelu(0)
33
+ cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64
34
+ convRelu(1)
35
+ cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32
36
+ convRelu(2, True)
37
+ convRelu(3)
38
+ cnn.add_module('pooling{0}'.format(2),
39
+ nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
40
+ convRelu(4, True)
41
+ convRelu(5)
42
+ cnn.add_module('pooling{0}'.format(3),
43
+ nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
44
+ convRelu(6, True) # 512x1x16
45
+
46
+ self.cnn = cnn
47
+ self.rnn = nn.Sequential(
48
+ BidirectionalLSTM(512, nh, nh),
49
+ BidirectionalLSTM(nh, nh, nclass))
50
+
51
+ def forward(self, input):
52
+ # conv features
53
+ conv = self.cnn(input)
54
+ b, c, h, w = conv.size()
55
+ assert h == 1, 'the height of conv must be 1'
56
+ conv = conv.squeeze(2)
57
+ conv = conv.transpose(1, 2) # [b, w, c]
58
+
59
+ # rnn features
60
+ output = self.rnn(conv)
61
+
62
+ return output
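Quick shape check for CRNN (editor's illustrative sketch, not part of the commit): img_h must be a multiple of 16, and the CNN collapses the feature-map height to 1 before the BiLSTM stack.

import torch
from strhub.models.crnn.model import CRNN

model = CRNN(img_h=32, nc=3, nclass=37, nh=256)
logits = model(torch.rand(2, 3, 32, 128))
print(logits.shape)   # (2, T, 37); T is determined by the input width after pooling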
strhub/models/crnn/system.py ADDED
@@ -0,0 +1,43 @@
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Sequence, Optional
17
+
18
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
19
+ from torch import Tensor
20
+
21
+ from strhub.models.base import CTCSystem
22
+ from strhub.models.utils import init_weights
23
+ from .model import CRNN as Model
24
+
25
+
26
+ class CRNN(CTCSystem):
27
+
28
+ def __init__(self, charset_train: str, charset_test: str, max_label_length: int,
29
+ batch_size: int, lr: float, warmup_pct: float, weight_decay: float,
30
+ img_size: Sequence[int], hidden_size: int, leaky_relu: bool, **kwargs) -> None:
31
+ super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
32
+ self.save_hyperparameters()
33
+ self.model = Model(img_size[0], 3, len(self.tokenizer), hidden_size, leaky_relu)
34
+ self.model.apply(init_weights)
35
+
36
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
37
+ return self.model.forward(images)
38
+
39
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
40
+ images, labels = batch
41
+ loss = self.forward_logits_loss(images, labels)[1]
42
+ self.log('loss', loss)
43
+ return loss
strhub/models/modules.py ADDED
@@ -0,0 +1,20 @@
1
+ r"""Shared modules used by CRNN and TRBA"""
2
+ from torch import nn
3
+
4
+
5
+ class BidirectionalLSTM(nn.Module):
6
+ """Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py"""
7
+
8
+ def __init__(self, input_size, hidden_size, output_size):
9
+ super().__init__()
10
+ self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
11
+ self.linear = nn.Linear(hidden_size * 2, output_size)
12
+
13
+ def forward(self, input):
14
+ """
15
+ input : visual feature [batch_size x T x input_size], T = num_steps.
16
+ output : contextual feature [batch_size x T x output_size]
17
+ """
18
+ recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
19
+ output = self.linear(recurrent) # batch_size x T x output_size
20
+ return output
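Usage sketch for BidirectionalLSTM, showing the shape contract documented in its docstring (editor's illustration, not part of the commit):

import torch
from strhub.models.modules import BidirectionalLSTM

rnn = BidirectionalLSTM(input_size=512, hidden_size=256, output_size=37)
x = torch.rand(4, 26, 512)    # batch_size x T x input_size
print(rnn(x).shape)           # torch.Size([4, 26, 37])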
strhub/models/parseq/__init__.py ADDED
File without changes
strhub/models/parseq/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (148 Bytes). View file
 
strhub/models/parseq/__pycache__/modules.cpython-37.pyc ADDED
Binary file (4.96 kB). View file
 
strhub/models/parseq/__pycache__/system.cpython-37.pyc ADDED
Binary file (7.62 kB). View file
 
strhub/models/parseq/modules.py ADDED
@@ -0,0 +1,126 @@
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import Optional
18
+
19
+ import torch
20
+ from torch import nn as nn, Tensor
21
+ from torch.nn import functional as F
22
+ from torch.nn.modules import transformer
23
+
24
+ from timm.models.vision_transformer import VisionTransformer, PatchEmbed
25
+
26
+
27
+ class DecoderLayer(nn.Module):
28
+ """A Transformer decoder layer supporting two-stream attention (XLNet)
29
+ This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
30
+
31
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu',
32
+ layer_norm_eps=1e-5):
33
+ super().__init__()
34
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
35
+ self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
36
+ # Implementation of Feedforward model
37
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
38
+ self.dropout = nn.Dropout(dropout)
39
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
40
+
41
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
42
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
43
+ self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps)
44
+ self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps)
45
+ self.dropout1 = nn.Dropout(dropout)
46
+ self.dropout2 = nn.Dropout(dropout)
47
+ self.dropout3 = nn.Dropout(dropout)
48
+
49
+ self.activation = transformer._get_activation_fn(activation)
50
+
51
+ def __setstate__(self, state):
52
+ if 'activation' not in state:
53
+ state['activation'] = F.gelu
54
+ super().__setstate__(state)
55
+
56
+ def forward_stream(self, tgt: Tensor, tgt_norm: Tensor, tgt_kv: Tensor, memory: Tensor, tgt_mask: Optional[Tensor],
57
+ tgt_key_padding_mask: Optional[Tensor]):
58
+ """Forward pass for a single stream (i.e. content or query)
59
+ tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
60
+ Both tgt_kv and memory are expected to be LayerNorm'd too.
61
+ memory is LayerNorm'd by ViT.
62
+ """
63
+ tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask,
64
+ key_padding_mask=tgt_key_padding_mask)
65
+ tgt = tgt + self.dropout1(tgt2)
66
+
67
+ tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
68
+ tgt = tgt + self.dropout2(tgt2)
69
+
70
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
71
+ tgt = tgt + self.dropout3(tgt2)
72
+ return tgt, sa_weights, ca_weights
73
+
74
+ def forward(self, query, content, memory, query_mask: Optional[Tensor] = None, content_mask: Optional[Tensor] = None,
75
+ content_key_padding_mask: Optional[Tensor] = None, update_content: bool = True):
76
+ query_norm = self.norm_q(query)
77
+ content_norm = self.norm_c(content)
78
+ query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0]
79
+ if update_content:
80
+ content = self.forward_stream(content, content_norm, content_norm, memory, content_mask,
81
+ content_key_padding_mask)[0]
82
+ return query, content
83
+
84
+
85
+ class Decoder(nn.Module):
86
+ __constants__ = ['norm']
87
+
88
+ def __init__(self, decoder_layer, num_layers, norm):
89
+ super().__init__()
90
+ self.layers = transformer._get_clones(decoder_layer, num_layers)
91
+ self.num_layers = num_layers
92
+ self.norm = norm
93
+
94
+ def forward(self, query, content, memory, query_mask: Optional[Tensor] = None, content_mask: Optional[Tensor] = None,
95
+ content_key_padding_mask: Optional[Tensor] = None):
96
+ for i, mod in enumerate(self.layers):
97
+ last = i == len(self.layers) - 1
98
+ query, content = mod(query, content, memory, query_mask, content_mask, content_key_padding_mask,
99
+ update_content=not last)
100
+ query = self.norm(query)
101
+ return query
102
+
103
+
104
+ class Encoder(VisionTransformer):
105
+
106
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
107
+ qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed):
108
+ super().__init__(img_size, patch_size, in_chans, embed_dim=embed_dim, depth=depth, num_heads=num_heads,
109
+ mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
110
+ drop_path_rate=drop_path_rate, embed_layer=embed_layer,
111
+ num_classes=0, global_pool='', class_token=False) # these disable the classifier head
112
+
113
+ def forward(self, x):
114
+ # Return all tokens
115
+ return self.forward_features(x)
116
+
117
+
118
+ class TokenEmbedding(nn.Module):
119
+
120
+ def __init__(self, charset_size: int, embed_dim: int):
121
+ super().__init__()
122
+ self.embedding = nn.Embedding(charset_size, embed_dim)
123
+ self.embed_dim = embed_dim
124
+
125
+ def forward(self, tokens: torch.Tensor):
126
+ return math.sqrt(self.embed_dim) * self.embedding(tokens)
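One-line check of TokenEmbedding's scaling behaviour (editor's illustrative sketch, not part of the commit):

import torch
from strhub.models.parseq.modules import TokenEmbedding

emb = TokenEmbedding(charset_size=97, embed_dim=384)
out = emb(torch.randint(0, 97, (2, 26)))
print(out.shape)              # (2, 26, 384); values are scaled by sqrt(384)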