Isaacgonzales committed
Commit d02e83e • 1 Parent(s): d59ff1a

add model

This view is limited to 50 files because the commit contains too many changes.
- app.py +39 -0
- model.py +52 -0
- post.py +42 -0
- requirements.txt +7 -0
- strhub/__init__.py +0 -0
- strhub/__pycache__/__init__.cpython-37.pyc +0 -0
- strhub/data/.ipynb_checkpoints/dataset-checkpoint.py +137 -0
- strhub/data/.ipynb_checkpoints/module-checkpoint.py +107 -0
- strhub/data/__init__.py +0 -0
- strhub/data/__pycache__/__init__.cpython-37.pyc +0 -0
- strhub/data/__pycache__/aa_overrides.cpython-37.pyc +0 -0
- strhub/data/__pycache__/augment.cpython-37.pyc +0 -0
- strhub/data/__pycache__/dataset.cpython-37.pyc +0 -0
- strhub/data/__pycache__/module.cpython-37.pyc +0 -0
- strhub/data/__pycache__/utils.cpython-37.pyc +0 -0
- strhub/data/aa_overrides.py +46 -0
- strhub/data/augment.py +111 -0
- strhub/data/dataset.py +137 -0
- strhub/data/module.py +107 -0
- strhub/data/utils.py +148 -0
- strhub/models/.ipynb_checkpoints/base-checkpoint.py +202 -0
- strhub/models/.ipynb_checkpoints/modules-checkpoint.py +20 -0
- strhub/models/.ipynb_checkpoints/utils-checkpoint.py +123 -0
- strhub/models/__init__.py +0 -0
- strhub/models/__pycache__/__init__.cpython-37.pyc +0 -0
- strhub/models/__pycache__/base.cpython-37.pyc +0 -0
- strhub/models/__pycache__/utils.cpython-37.pyc +0 -0
- strhub/models/abinet/LICENSE +25 -0
- strhub/models/abinet/__init__.py +13 -0
- strhub/models/abinet/attention.py +100 -0
- strhub/models/abinet/backbone.py +24 -0
- strhub/models/abinet/model.py +31 -0
- strhub/models/abinet/model_abinet_iter.py +39 -0
- strhub/models/abinet/model_alignment.py +28 -0
- strhub/models/abinet/model_language.py +50 -0
- strhub/models/abinet/model_vision.py +45 -0
- strhub/models/abinet/resnet.py +72 -0
- strhub/models/abinet/system.py +172 -0
- strhub/models/abinet/transformer.py +143 -0
- strhub/models/base.py +202 -0
- strhub/models/crnn/LICENSE +21 -0
- strhub/models/crnn/__init__.py +13 -0
- strhub/models/crnn/model.py +62 -0
- strhub/models/crnn/system.py +43 -0
- strhub/models/modules.py +20 -0
- strhub/models/parseq/__init__.py +0 -0
- strhub/models/parseq/__pycache__/__init__.cpython-37.pyc +0 -0
- strhub/models/parseq/__pycache__/modules.cpython-37.pyc +0 -0
- strhub/models/parseq/__pycache__/system.cpython-37.pyc +0 -0
- strhub/models/parseq/modules.py +126 -0
app.py
ADDED
@@ -0,0 +1,39 @@
import gradio as gr
import shutil
import urllib.request
import sys
import os
import zipfile

sys.path.append(".")

from model import prediction

gr.close_all()

# Fetch and unpack the demo images from
# https://storage.googleapis.com/models-gradio/products/products.zip
urllib.request.urlretrieve("https://storage.googleapis.com/models-gradio/products/products.zip",
                           "products.zip")

with zipfile.ZipFile("products.zip", 'r') as zip_ref:
    zip_ref.extractall()


def predict(img):
    prediction_img, text = prediction(img)
    return str(text), prediction_img


sample_images = ["dataset/" + name for name in os.listdir("dataset")]

gr.Interface(fn=predict,
             inputs=[gr.Image(label="image à tester", type="filepath")],
             outputs=[gr.Textbox(label="analyse"), gr.Image(label="résultat")],
             css="footer {visibility: hidden} body, .gradio-container {background-color: white}",
             examples=sample_images).launch(server_name="0.0.0.0", share=False)
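A quick sanity check of the predict() wrapper, e.g. from a REPL with the final .launch() call skipped. The sample file name is hypothetical; any image extracted from the zip works:

    text, annotated = predict("dataset/sample.jpg")  # hypothetical image path
    print(text)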
model.py
ADDED
@@ -0,0 +1,52 @@
from strhub.data.module import SceneTextDataModule
from strhub.models.utils import load_from_checkpoint
from post import filter_mask
import segmentation_models_pytorch as smp
import albumentations as albu
from torchvision import transforms
from PIL import Image
import torch
import cv2

# Text recognition model (PARSeq) and its input transform.
model_recog = load_from_checkpoint("weights/parseq/last.ckpt").eval().to("cpu")
img_transform = SceneTextDataModule.get_transform(model_recog.hparams.img_size)

# Segmentation model used to locate the text region.
model = torch.load('weights/best_model.pth').to("cpu")
model.eval()
model.float()

SHAPE_X = 384
SHAPE_Y = 384


def prediction(image_path):
    image = cv2.imread(image_path)
    image_original = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    preprocessing_fn = smp.encoders.get_preprocessing_fn('resnet50')
    transform = albu.Compose([
        albu.Lambda(image=preprocessing_fn), albu.Resize(SHAPE_X, SHAPE_Y)
    ])

    image_result = transform(image=image_original)["image"]
    transform = transforms.ToTensor()
    tensor = transform(image_result)
    tensor = torch.unsqueeze(tensor, 0)
    output = model.predict(tensor.float())

    # Crop and deskew the detected text region, then recognize it.
    result, img_vis = filter_mask(output, image_original)

    image = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
    im_pil = Image.fromarray(image)
    image = img_transform(im_pil).unsqueeze(0).to("cpu")

    p = model_recog(image).softmax(-1)
    pred, p = model_recog.tokenizer.decode(p)
    print(f'{image_path}: {pred[0]}')

    return img_vis, pred[0]
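A minimal sketch of calling the two-stage pipeline directly, assuming the weights/ checkpoints are in place; the image path is hypothetical:

    # img_vis is the RGB image with the detected box drawn; text is the decoded string.
    img_vis, text = prediction("dataset/sample.jpg")  # hypothetical path
    cv2.imwrite("result.jpg", cv2.cvtColor(img_vis, cv2.COLOR_RGB2BGR))
    print(text)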
post.py
ADDED
@@ -0,0 +1,42 @@
import cv2
import numpy as np


def filter_mask(output, image):
    """Crop and deskew the text region given the segmentation output."""
    image_h, image_w = image.shape[:2]

    # Binarize the predicted mask and scale it back to the original image size.
    predict_mask = output.squeeze().cpu().numpy().round()
    predict_mask = predict_mask.astype('uint8') * 255
    predict_mask = cv2.resize(predict_mask, (image_w, image_h))

    ret, thresh = cv2.threshold(predict_mask, 127, 255, 0)
    # Note: the two-value return assumes OpenCV >= 4.
    contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    points = contours[0]

    # Fit a rotated rectangle around the first contour and draw it for visualization.
    rect = cv2.minAreaRect(points)
    box = cv2.boxPoints(rect)
    box = np.int0(box)

    img_vis = cv2.drawContours(image.copy(), [box], 0, (255, 0, 0), 6)

    (cX, cY), (w, h), angle = rect

    if w < h:
        angle -= 90

    # (cX, cY) = (image_w // 2, image_h // 2)

    # Rotate the whole image so the detected rectangle becomes axis-aligned.
    M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
    rotated = cv2.warpAffine(image, M, (round(image_w), round(image_h)))

    # Apply the same rotation to the contour points and re-fit the rectangle.
    ones = np.ones(shape=(len(points), 1))
    points = np.squeeze(points)
    points_ones = np.hstack([points, ones])
    points_rotate = M.dot(points_ones.T).T

    (cX, cY), (w, h), angle = cv2.minAreaRect(points_rotate.round().astype(int))

    # Crop the now axis-aligned region (w and h may be swapped depending on the angle).
    if angle < 45:
        crop_img = rotated[round(cY - h // 2):round(cY + h // 2), round(cX - w // 2):round(cX + w // 2)]
    else:
        crop_img = rotated[round(cY - w // 2):round(cY + w // 2), round(cX - h // 2):round(cX + h // 2)]
    return crop_img, img_vis
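filter_mask leans on the cv2.minAreaRect convention of returning ((cX, cY), (w, h), angle). A small self-contained illustration of that convention, not part of the commit:

    import cv2
    import numpy as np

    pts = np.array([[10, 10], [110, 10], [110, 40], [10, 40]], dtype=np.float32)
    (cx, cy), (w, h), angle = cv2.minAreaRect(pts)
    # Centre (60.0, 25.0); size 100 x 30 up to a swap of w and h;
    # on OpenCV >= 4.5 the angle lies in (0, 90].
    print((cx, cy), (w, h), angle)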
requirements.txt
ADDED
@@ -0,0 +1,7 @@
opencv-python
torch
torchvision
gradio
segmentation_models_pytorch
albumentations
Pillow
strhub/__init__.py
ADDED
File without changes
strhub/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (134 Bytes).
strhub/data/.ipynb_checkpoints/dataset-checkpoint.py
ADDED
@@ -0,0 +1,137 @@
(Jupyter checkpoint copy; its 137 added lines duplicate strhub/data/dataset.py, reproduced in full below.)
strhub/data/.ipynb_checkpoints/module-checkpoint.py
ADDED
@@ -0,0 +1,107 @@
(Jupyter checkpoint copy; its 107 added lines duplicate strhub/data/module.py, reproduced in full below.)
strhub/data/__init__.py
ADDED
File without changes
strhub/data/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (139 Bytes).

strhub/data/__pycache__/aa_overrides.cpython-37.pyc
ADDED
Binary file (1.22 kB).

strhub/data/__pycache__/augment.cpython-37.pyc
ADDED
Binary file (3.54 kB).

strhub/data/__pycache__/dataset.cpython-37.pyc
ADDED
Binary file (3.97 kB).

strhub/data/__pycache__/module.cpython-37.pyc
ADDED
Binary file (4.11 kB).

strhub/data/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (6.83 kB).
strhub/data/aa_overrides.py
ADDED
@@ -0,0 +1,46 @@
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Extends default ops to accept optional parameters."""
from functools import partial

from timm.data.auto_augment import _LEVEL_DENOM, _randomly_negate, LEVEL_TO_ARG, NAME_TO_OP, rotate


def rotate_expand(img, degrees, **kwargs):
    """Rotate operation with expand=True to avoid cutting off the characters"""
    kwargs['expand'] = True
    return rotate(img, degrees, **kwargs)


def _level_to_arg(level, hparams, key, default):
    magnitude = hparams.get(key, default)
    level = (level / _LEVEL_DENOM) * magnitude
    level = _randomly_negate(level)
    return level,


def apply():
    # Overrides
    NAME_TO_OP.update({
        'Rotate': rotate_expand
    })
    LEVEL_TO_ARG.update({
        'Rotate': partial(_level_to_arg, key='rotate_deg', default=30.),
        'ShearX': partial(_level_to_arg, key='shear_x_pct', default=0.3),
        'ShearY': partial(_level_to_arg, key='shear_y_pct', default=0.3),
        'TranslateXRel': partial(_level_to_arg, key='translate_x_pct', default=0.45),
        'TranslateYRel': partial(_level_to_arg, key='translate_y_pct', default=0.45),
    })
strhub/data/augment.py
ADDED
@@ -0,0 +1,111 @@
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import imgaug.augmenters as iaa
import numpy as np
from PIL import ImageFilter, Image
from timm.data import auto_augment

from strhub.data import aa_overrides

aa_overrides.apply()

_OP_CACHE = {}


def _get_op(key, factory):
    try:
        op = _OP_CACHE[key]
    except KeyError:
        op = factory()
        _OP_CACHE[key] = op
    return op


def _get_param(level, img, max_dim_factor, min_level=1):
    max_level = max(min_level, max_dim_factor * max(img.size))
    return round(min(level, max_level))


def gaussian_blur(img, radius, **__):
    radius = _get_param(radius, img, 0.02)
    key = 'gaussian_blur_' + str(radius)
    op = _get_op(key, lambda: ImageFilter.GaussianBlur(radius))
    return img.filter(op)


def motion_blur(img, k, **__):
    k = _get_param(k, img, 0.08, 3) | 1  # bin to odd values
    key = 'motion_blur_' + str(k)
    op = _get_op(key, lambda: iaa.MotionBlur(k))
    return Image.fromarray(op(image=np.asarray(img)))


def gaussian_noise(img, scale, **_):
    scale = _get_param(scale, img, 0.25) | 1  # bin to odd values
    key = 'gaussian_noise_' + str(scale)
    op = _get_op(key, lambda: iaa.AdditiveGaussianNoise(scale=scale))
    return Image.fromarray(op(image=np.asarray(img)))


def poisson_noise(img, lam, **_):
    lam = _get_param(lam, img, 0.2) | 1  # bin to odd values
    key = 'poisson_noise_' + str(lam)
    op = _get_op(key, lambda: iaa.AdditivePoissonNoise(lam))
    return Image.fromarray(op(image=np.asarray(img)))


def _level_to_arg(level, _hparams, max):
    level = max * level / auto_augment._LEVEL_DENOM
    return level,


_RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy()
_RAND_TRANSFORMS.remove('SharpnessIncreasing')  # remove, interferes with *blur ops
_RAND_TRANSFORMS.extend([
    'GaussianBlur',
    # 'MotionBlur',
    # 'GaussianNoise',
    'PoissonNoise'
])
auto_augment.LEVEL_TO_ARG.update({
    'GaussianBlur': partial(_level_to_arg, max=4),
    'MotionBlur': partial(_level_to_arg, max=20),
    'GaussianNoise': partial(_level_to_arg, max=0.1 * 255),
    'PoissonNoise': partial(_level_to_arg, max=40)
})
auto_augment.NAME_TO_OP.update({
    'GaussianBlur': gaussian_blur,
    'MotionBlur': motion_blur,
    'GaussianNoise': gaussian_noise,
    'PoissonNoise': poisson_noise
})


def rand_augment_transform(magnitude=5, num_layers=3):
    # These are tuned for magnitude=5, which means that effective magnitudes are half of these values.
    hparams = {
        'rotate_deg': 30,
        'shear_x_pct': 0.9,
        'shear_y_pct': 0.2,
        'translate_x_pct': 0.10,
        'translate_y_pct': 0.30
    }
    ra_ops = auto_augment.rand_augment_ops(magnitude, hparams, transforms=_RAND_TRANSFORMS)
    # Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice)
    choice_weights = [1. / len(ra_ops) for _ in range(len(ra_ops))]
    return auto_augment.RandAugment(ra_ops, num_layers, choice_weights)
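A minimal usage sketch, assuming a timm version whose RandAugment object is callable on PIL images (as in the versions this code targets); the input image is a synthetic stand-in:

    from PIL import Image

    aug = rand_augment_transform()              # magnitude=5, 3 layers
    img = Image.new('RGB', (128, 32), 'white')  # stand-in for a text crop
    augmented = aug(img)                        # PIL image with 3 random ops applied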
strhub/data/dataset.py
ADDED
@@ -0,0 +1,137 @@
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import io
import logging
import unicodedata
from pathlib import Path, PurePath
from typing import Callable, Optional, Union

import lmdb
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset

from strhub.data.utils import CharsetAdapter

log = logging.getLogger(__name__)


def build_tree_dataset(root: Union[PurePath, str], *args, **kwargs):
    try:
        kwargs.pop('root')  # prevent 'root' from being passed via kwargs
    except KeyError:
        pass
    root = Path(root).absolute()
    log.info(f'dataset root:\t{root}')
    datasets = []
    for mdb in glob.glob(str(root / '**/data.mdb'), recursive=True):
        mdb = Path(mdb)
        ds_name = str(mdb.parent.relative_to(root))
        ds_root = str(mdb.parent.absolute())
        dataset = LmdbDataset(ds_root, *args, **kwargs)
        log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}')
        datasets.append(dataset)
    return ConcatDataset(datasets)


class LmdbDataset(Dataset):
    """Dataset interface to an LMDB database.

    It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned
    as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset.
    Labels are transformed according to the charset.
    """

    def __init__(self, root: str, charset: str, max_label_len: int, min_image_dim: int = 0,
                 remove_whitespace: bool = True, normalize_unicode: bool = True,
                 unlabelled: bool = False, transform: Optional[Callable] = None):
        self._env = None
        self.root = root
        self.unlabelled = unlabelled
        self.transform = transform
        self.labels = []
        self.filtered_index_list = []
        self.num_samples = self._preprocess_labels(charset, remove_whitespace, normalize_unicode,
                                                   max_label_len, min_image_dim)

    def __del__(self):
        if self._env is not None:
            self._env.close()
            self._env = None

    def _create_env(self):
        return lmdb.open(self.root, max_readers=1, readonly=True, create=False,
                         readahead=False, meminit=False, lock=False)

    @property
    def env(self):
        if self._env is None:
            self._env = self._create_env()
        return self._env

    def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim):
        charset_adapter = CharsetAdapter(charset)
        with self._create_env() as env, env.begin() as txn:
            num_samples = int(txn.get('num-samples'.encode()))
            if self.unlabelled:
                return num_samples
            for index in range(num_samples):
                index += 1  # lmdb starts with 1
                label_key = f'label-{index:09d}'.encode()
                label = txn.get(label_key).decode()
                # Normally, whitespace is removed from the labels.
                if remove_whitespace:
                    label = ''.join(label.split())
                # Normalize unicode composites (if any) and convert to compatible ASCII characters
                if normalize_unicode:
                    label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode()
                # Filter by length before removing unsupported characters. The original label might be too long.
                if len(label) > max_label_len:
                    continue
                label = charset_adapter(label)
                # We filter out samples which don't contain any supported characters
                if not label:
                    continue
                # Filter images that are too small.
                if min_image_dim > 0:
                    img_key = f'image-{index:09d}'.encode()
                    buf = io.BytesIO(txn.get(img_key))
                    w, h = Image.open(buf).size
                    # Fixed: the original referenced self.min_image_dim, which is never assigned.
                    if w < min_image_dim or h < min_image_dim:
                        continue
                self.labels.append(label)
                self.filtered_index_list.append(index)
        return len(self.labels)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        if self.unlabelled:
            label = index
        else:
            label = self.labels[index]
            index = self.filtered_index_list[index]

        img_key = f'image-{index:09d}'.encode()
        with self.env.begin() as txn:
            imgbuf = txn.get(img_key)
        buf = io.BytesIO(imgbuf)
        img = Image.open(buf).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label
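Usage is straightforward once an LMDB with the expected key layout (num-samples, label-%09d, image-%09d) exists on disk. A hedged sketch with a hypothetical path:

    # Hypothetical LMDB location; keys must follow the layout read by _preprocess_labels().
    ds = LmdbDataset('data/train/real', charset='0123456789abcdefghijklmnopqrstuvwxyz',
                     max_label_len=25)
    img, label = ds[0]  # PIL image (or transformed tensor) and its filtered label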
strhub/data/module.py
ADDED
@@ -0,0 +1,107 @@
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import PurePath
from typing import Optional, Callable, Sequence, Tuple

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms as T

from .dataset import build_tree_dataset, LmdbDataset


class SceneTextDataModule(pl.LightningDataModule):
    TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80')
    TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80')
    TEST_NEW = ('ArT', 'COCOv1.4', 'Uber')
    TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW))

    def __init__(self, root_dir: str, train_dir: str, img_size: Sequence[int], max_label_length: int,
                 charset_train: str, charset_test: str, batch_size: int, num_workers: int, augment: bool,
                 remove_whitespace: bool = True, normalize_unicode: bool = True,
                 min_image_dim: int = 0, rotation: int = 0, collate_fn: Optional[Callable] = None):
        super().__init__()
        self.root_dir = root_dir
        self.train_dir = train_dir
        self.img_size = tuple(img_size)
        self.max_label_length = max_label_length
        self.charset_train = charset_train
        self.charset_test = charset_test
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.augment = augment
        self.remove_whitespace = remove_whitespace
        self.normalize_unicode = normalize_unicode
        self.min_image_dim = min_image_dim
        self.rotation = rotation
        self.collate_fn = collate_fn
        self._train_dataset = None
        self._val_dataset = None

    @staticmethod
    def get_transform(img_size: Tuple[int], augment: bool = False, rotation: int = 0):
        transforms = []
        if augment:
            from .augment import rand_augment_transform
            transforms.append(rand_augment_transform())
        if rotation:
            transforms.append(lambda img: img.rotate(rotation, expand=True))
        transforms.extend([
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5)
        ])
        return T.Compose(transforms)

    @property
    def train_dataset(self):
        if self._train_dataset is None:
            transform = self.get_transform(self.img_size, self.augment)
            root = PurePath(self.root_dir, 'train', self.train_dir)
            self._train_dataset = build_tree_dataset(root, self.charset_train, self.max_label_length,
                                                     self.min_image_dim, self.remove_whitespace,
                                                     self.normalize_unicode, transform=transform)
        return self._train_dataset

    @property
    def val_dataset(self):
        if self._val_dataset is None:
            transform = self.get_transform(self.img_size)
            root = PurePath(self.root_dir, 'val')
            self._val_dataset = build_tree_dataset(root, self.charset_test, self.max_label_length,
                                                   self.min_image_dim, self.remove_whitespace,
                                                   self.normalize_unicode, transform=transform)
        return self._val_dataset

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, persistent_workers=self.num_workers > 0,
                          pin_memory=True, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, persistent_workers=self.num_workers > 0,
                          pin_memory=True, collate_fn=self.collate_fn)

    def test_dataloaders(self, subset):
        transform = self.get_transform(self.img_size, rotation=self.rotation)
        root = PurePath(self.root_dir, 'test')
        datasets = {s: LmdbDataset(str(root / s), self.charset_test, self.max_label_length,
                                   self.min_image_dim, self.remove_whitespace, self.normalize_unicode,
                                   transform=transform) for s in subset}
        return {k: DataLoader(v, batch_size=self.batch_size, num_workers=self.num_workers,
                              pin_memory=True, collate_fn=self.collate_fn)
                for k, v in datasets.items()}
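The static get_transform is also used standalone for inference elsewhere in this commit (see model.py). A minimal sketch, assuming PARSeq's usual 32x128 input size:

    transform = SceneTextDataModule.get_transform((32, 128))  # resize, ToTensor, normalize to [-1, 1]
    # tensor = transform(pil_image).unsqueeze(0)              # shape: (1, 3, 32, 128)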
strhub/data/utils.py
ADDED
@@ -0,0 +1,148 @@
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from abc import ABC, abstractmethod
from itertools import groupby
from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence


class CharsetAdapter:
    """Transforms labels according to the target charset."""

    def __init__(self, target_charset) -> None:
        super().__init__()
        self.lowercase_only = target_charset == target_charset.lower()
        self.uppercase_only = target_charset == target_charset.upper()
        self.unsupported = f'[^{re.escape(target_charset)}]'

    def __call__(self, label):
        if self.lowercase_only:
            label = label.lower()
        elif self.uppercase_only:
            label = label.upper()
        # Remove unsupported characters
        label = re.sub(self.unsupported, '', label)
        return label


class BaseTokenizer(ABC):

    def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
        self._itos = specials_first + tuple(charset) + specials_last
        self._stoi = {s: i for i, s in enumerate(self._itos)}

    def __len__(self):
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> List[int]:
        return [self._stoi[s] for s in tokens]

    def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
        tokens = [self._itos[i] for i in token_ids]
        return ''.join(tokens) if join else tokens

    @abstractmethod
    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode a batch of labels to a representation suitable for the model.

        Args:
            labels: List of labels. Each can be of arbitrary length.
            device: Create tensor on this device.

        Returns:
            Batched tensor representation padded to the max label length. Shape: N, L
        """
        raise NotImplementedError

    @abstractmethod
    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Internal method which performs the necessary filtering prior to decoding."""
        raise NotImplementedError

    def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        """Decode a batch of token distributions.

        Args:
            token_dists: softmax probabilities over the token distribution. Shape: N, L, C
            raw: return unprocessed labels (will return list of list of strings)

        Returns:
            list of string labels (arbitrary length) and
            their corresponding sequence probabilities as a list of Tensors
        """
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)  # greedy selection
            if not raw:
                probs, ids = self._filter(probs, ids)
            tokens = self._ids2tok(ids, not raw)
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs


class Tokenizer(BaseTokenizer):
    BOS = '[B]'
    EOS = '[E]'
    PAD = '[P]'

    def __init__(self, charset: str) -> None:
        specials_first = (self.EOS,)
        specials_last = (self.BOS, self.PAD)
        super().__init__(charset, specials_first, specials_last)
        self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last]

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device)
                 for y in labels]
        return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        ids = ids.tolist()
        try:
            eos_idx = ids.index(self.eos_id)
        except ValueError:
            eos_idx = len(ids)  # Nothing to truncate.
        # Truncate after EOS
        ids = ids[:eos_idx]
        probs = probs[:eos_idx + 1]  # but include prob. for EOS (if it exists)
        return probs, ids


class CTCTokenizer(BaseTokenizer):
    BLANK = '[B]'

    def __init__(self, charset: str) -> None:
        # BLANK uses index == 0 by default
        super().__init__(charset, specials_first=(self.BLANK,))
        self.blank_id = self._stoi[self.BLANK]

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        # We use a padded representation since we don't want to use CUDNN's CTC implementation
        batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels]
        return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        # Best path decoding:
        ids = list(zip(*groupby(ids.tolist())))[0]  # Remove duplicate tokens
        ids = [x for x in ids if x != self.blank_id]  # Remove BLANKs
        # `probs` is just pass-through since all positions are considered part of the path
        return probs, ids
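A short round-trip sketch of CharsetAdapter and Tokenizer; the charset is chosen for illustration only:

    charset = '0123456789abcdefghijklmnopqrstuvwxyz'
    adapter = CharsetAdapter(charset)
    print(adapter('Hello, World!'))  # 'helloworld': lowercased, unsupported chars stripped

    tok = Tokenizer(charset)
    batch = tok.encode(['hello', 'hi'])
    print(batch.shape)               # (2, 7): BOS + longest label + EOS, shorter label padded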
strhub/models/.ipynb_checkpoints/base-checkpoint.py
ADDED
@@ -0,0 +1,202 @@
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple, List

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from nltk import edit_distance
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from timm.optim import create_optimizer_v2
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import OneCycleLR

from strhub.data.utils import CharsetAdapter, CTCTokenizer, Tokenizer, BaseTokenizer


@dataclass
class BatchResult:
    num_samples: int
    correct: int
    ned: float
    confidence: float
    label_length: int
    loss: Tensor
    loss_numel: int


class BaseSystem(pl.LightningModule, ABC):

    def __init__(self, tokenizer: BaseTokenizer, charset_test: str,
                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.charset_adapter = CharsetAdapter(charset_test)
        self.batch_size = batch_size
        self.lr = lr
        self.warmup_pct = warmup_pct
        self.weight_decay = weight_decay

    @abstractmethod
    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
        """Inference

        Args:
            images: Batch of images. Shape: N, Ch, H, W
            max_length: Max sequence length of the output. If None, will use default.

        Returns:
            logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
        """
        raise NotImplementedError

    @abstractmethod
    def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
        """Like forward(), but also computes the loss (calls forward() internally).

        Args:
            images: Batch of images. Shape: N, Ch, H, W
            labels: Text labels of the images

        Returns:
            logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
            loss: mean loss for the batch
            loss_numel: number of elements the loss was calculated from
        """
        raise NotImplementedError

    def configure_optimizers(self):
        agb = self.trainer.accumulate_grad_batches
        # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP.
        lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.
        lr = lr_scale * self.lr
        optim = create_optimizer_v2(self, 'adamw', lr, self.weight_decay)
        sched = OneCycleLR(optim, lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct,
                           cycle_momentum=False)
        return {'optimizer': optim, 'lr_scheduler': {'scheduler': sched, 'interval': 'step'}}

    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
        optimizer.zero_grad(set_to_none=True)

    def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]:
        images, labels = batch

        correct = 0
        total = 0
        ned = 0
        confidence = 0
        label_length = 0
        if validation:
            logits, loss, loss_numel = self.forward_logits_loss(images, labels)
        else:
            # At test-time, we shouldn't specify a max_label_length because the test-time charset used
            # might be different from the train-time charset. max_label_length in eval_logits_loss() is computed
            # based on the transformed label, which could be wrong if the actual gt label contains characters existing
            # in the train-time charset but not in the test-time charset. For example, "aishahaleyes.blogspot.com"
            # is exactly 25 characters, but if processed by CharsetAdapter for the 36-char set, it becomes 23 characters
            # long only, which sets max_label_length = 23. This will cause the model prediction to be truncated.
            logits = self.forward(images)
            loss = loss_numel = None  # Only used for validation; not needed at test-time.

        probs = logits.softmax(-1)
        preds, probs = self.tokenizer.decode(probs)
        for pred, prob, gt in zip(preds, probs, labels):
            confidence += prob.prod().item()
            pred = self.charset_adapter(pred)
            # Follow ICDAR 2019 definition of N.E.D.
            ned += edit_distance(pred, gt) / max(len(pred), len(gt))
            if pred == gt:
                correct += 1
            total += 1
            label_length += len(pred)
        return dict(output=BatchResult(total, correct, ned, confidence, label_length, loss, loss_numel))

    @staticmethod
    def _aggregate_results(outputs: EPOCH_OUTPUT) -> Tuple[float, float, float]:
        if not outputs:
            return 0., 0., 0.
        total_loss = 0
        total_loss_numel = 0
        total_n_correct = 0
        total_norm_ED = 0
        total_size = 0
        for result in outputs:
            result = result['output']
            total_loss += result.loss_numel * result.loss
            total_loss_numel += result.loss_numel
            total_n_correct += result.correct
            total_norm_ED += result.ned
            total_size += result.num_samples
        acc = total_n_correct / total_size
        ned = (1 - total_norm_ED / total_size)
        loss = total_loss / total_loss_numel
        return acc, ned, loss

    def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
        return self._eval_step(batch, True)

    def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        acc, ned, loss = self._aggregate_results(outputs)
        self.log('val_accuracy', 100 * acc, sync_dist=True)
        self.log('val_NED', 100 * ned, sync_dist=True)
        self.log('val_loss', loss, sync_dist=True)
        self.log('hp_metric', acc, sync_dist=True)

    def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
        return self._eval_step(batch, False)


class CrossEntropySystem(BaseSystem):

    def __init__(self, charset_train: str, charset_test: str,
                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
        tokenizer = Tokenizer(charset_train)
        super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
        self.bos_id = tokenizer.bos_id
        self.eos_id = tokenizer.eos_id
        self.pad_id = tokenizer.pad_id

    def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
        targets = self.tokenizer.encode(labels, self.device)
        targets = targets[:, 1:]  # Discard <bos>
        max_len = targets.shape[1] - 1  # exclude <eos> from count
        logits = self.forward(images, max_len)
        loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id)
        loss_numel = (targets != self.pad_id).sum()
        return logits, loss, loss_numel


class CTCSystem(BaseSystem):

    def __init__(self, charset_train: str, charset_test: str,
                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
        tokenizer = CTCTokenizer(charset_train)
        super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
        self.blank_id = tokenizer.blank_id

    def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
        targets = self.tokenizer.encode(labels, self.device)
        logits = self.forward(images)
        log_probs = logits.log_softmax(-1).transpose(0, 1)  # swap batch and seq. dims
        T, N, _ = log_probs.shape
        input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long, device=self.device)
        target_lengths = torch.as_tensor(list(map(len, labels)), dtype=torch.long, device=self.device)
        loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True)
        return logits, loss, N
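The _eval_step metric follows the ICDAR 2019 definition of normalized edit distance. A one-line worked example, with illustrative values:

    from nltk import edit_distance

    pred, gt = 'parsec', 'parseq'
    ned = edit_distance(pred, gt) / max(len(pred), len(gt))  # 1 edit / 6 chars, approx. 0.167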
strhub/models/.ipynb_checkpoints/modules-checkpoint.py
ADDED
@@ -0,0 +1,20 @@
r"""Shared modules used by CRNN and TRBA"""
from torch import nn


class BidirectionalLSTM(nn.Module):
    """Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py"""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input):
        """
        input : visual feature [batch_size x T x input_size], T = num_steps.
        output : contextual feature [batch_size x T x output_size]
        """
        recurrent, _ = self.rnn(input)  # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
        output = self.linear(recurrent)  # batch_size x T x output_size
        return output
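A quick shape check for BidirectionalLSTM above (a minimal sketch; it assumes the class is in scope, and the tensor sizes are illustrative):

import torch

rnn = BidirectionalLSTM(input_size=512, hidden_size=256, output_size=37)
feats = torch.randn(4, 16, 512)   # (batch_size, T, input_size) visual features
out = rnn(feats)                  # (batch_size, T, output_size) contextual features
assert out.shape == (4, 16, 37)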
strhub/models/.ipynb_checkpoints/utils-checkpoint.py
ADDED
@@ -0,0 +1,123 @@
from pathlib import PurePath
from typing import Sequence

import torch
from torch import nn

import yaml


class InvalidModelError(RuntimeError):
    """Exception raised for any model-related error (creation, loading)"""


_WEIGHTS_URL = {
    'parseq-tiny': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq_tiny-e7a21b54.pt',
    'parseq': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq-bb5792a6.pt',
    'abinet': 'https://github.com/baudm/parseq/releases/download/v1.0.0/abinet-1d1e373e.pt',
    'trba': 'https://github.com/baudm/parseq/releases/download/v1.0.0/trba-cfaed284.pt',
    'vitstr': 'https://github.com/baudm/parseq/releases/download/v1.0.0/vitstr-26d0fcf4.pt',
    'crnn': 'https://github.com/baudm/parseq/releases/download/v1.0.0/crnn-679d0e31.pt',
}


def _get_config(experiment: str, **kwargs):
    """Emulates hydra config resolution"""
    root = PurePath(__file__).parents[2]
    with open(root / 'configs/main.yaml', 'r') as f:
        config = yaml.load(f, yaml.Loader)['model']
    with open(root / f'configs/charset/94_full.yaml', 'r') as f:
        config.update(yaml.load(f, yaml.Loader)['model'])
    with open(root / f'configs/experiment/{experiment}.yaml', 'r') as f:
        exp = yaml.load(f, yaml.Loader)
    # Apply base model config
    model = exp['defaults'][0]['override /model']
    with open(root / f'configs/model/{model}.yaml', 'r') as f:
        config.update(yaml.load(f, yaml.Loader))
    # Apply experiment config
    if 'model' in exp:
        config.update(exp['model'])
    config.update(kwargs)
    # Workaround for now: manually cast the lr to the correct type.
    config['lr'] = float(config['lr'])
    return config


def _get_model_class(key):
    if 'abinet' in key:
        from .abinet.system import ABINet as ModelClass
    elif 'crnn' in key:
        from .crnn.system import CRNN as ModelClass
    elif 'parseq' in key:
        from .parseq.system import PARSeq as ModelClass
    elif 'trba' in key:
        from .trba.system import TRBA as ModelClass
    elif 'trbc' in key:
        from .trba.system import TRBC as ModelClass
    elif 'vitstr' in key:
        from .vitstr.system import ViTSTR as ModelClass
    else:
        raise InvalidModelError("Unable to find model class for '{}'".format(key))
    return ModelClass


def get_pretrained_weights(experiment):
    try:
        url = _WEIGHTS_URL[experiment]
    except KeyError:
        raise InvalidModelError("No pretrained weights found for '{}'".format(experiment)) from None
    return torch.hub.load_state_dict_from_url(url=url, map_location='cpu', check_hash=True)


def create_model(experiment: str, pretrained: bool = False, **kwargs):
    try:
        config = _get_config(experiment, **kwargs)
    except FileNotFoundError:
        raise InvalidModelError("No configuration found for '{}'".format(experiment)) from None
    ModelClass = _get_model_class(experiment)
    model = ModelClass(**config)
    if pretrained:
        model.load_state_dict(get_pretrained_weights(experiment))
    return model


def load_from_checkpoint(checkpoint_path: str, **kwargs):
    if checkpoint_path.startswith('pretrained='):
        model_id = checkpoint_path.split('=', maxsplit=1)[1]
        model = create_model(model_id, True, **kwargs)
    else:
        ModelClass = _get_model_class(checkpoint_path)
        model = ModelClass.load_from_checkpoint(checkpoint_path, **kwargs)
    return model


def parse_model_args(args):
    kwargs = {}
    arg_types = {t.__name__: t for t in [int, float, str]}
    arg_types['bool'] = lambda v: v.lower() == 'true'  # special handling for bool
    for arg in args:
        name, value = arg.split('=', maxsplit=1)
        name, arg_type = name.split(':', maxsplit=1)
        kwargs[name] = arg_types[arg_type](value)
    return kwargs


def init_weights(module: nn.Module, name: str = '', exclude: Sequence[str] = ()):
    """Initialize the weights using the typical initialization schemes used in SOTA models."""
    if any(map(name.startswith, exclude)):
        return
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.trunc_normal_(module.weight, std=.02)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
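For context, a hedged sketch of how the factory helpers above compose; the experiment name and override keys below are illustrative, not prescribed by this file:

from strhub.models.utils import create_model, load_from_checkpoint, parse_model_args

# Build a model by experiment name; pretrained=True fetches weights from _WEIGHTS_URL.
model = create_model('parseq', pretrained=True)

# Equivalent shortcut accepted by load_from_checkpoint: 'pretrained=<experiment>'.
model = load_from_checkpoint('pretrained=parseq')

# CLI-style overrides use 'name:type=value' triples.
overrides = parse_model_args(['batch_size:int=384', 'lr:float=7e-4'])
# -> {'batch_size': 384, 'lr': 0.0007}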
strhub/models/__init__.py
ADDED
File without changes
strhub/models/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (141 Bytes)
strhub/models/__pycache__/base.cpython-37.pyc
ADDED
Binary file (7.54 kB)
strhub/models/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (4.69 kB)
strhub/models/abinet/LICENSE
ADDED
@@ -0,0 +1,25 @@
ABINet for non-commercial purposes

Copyright (c) 2021, USTC
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
strhub/models/abinet/__init__.py
ADDED
@@ -0,0 +1,13 @@
r"""
Fang, Shancheng, Hongtao Xie, Yuxin Wang, Zhendong Mao, and Yongdong Zhang.
"Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition."
In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 7098-7107). 2021.

https://arxiv.org/abs/2103.06495

All source files, except `system.py`, are based on the implementation listed below,
and hence are released under the license of the original.

Source: https://github.com/FangShancheng/ABINet
License: 2-clause BSD License (see included LICENSE file)
"""
strhub/models/abinet/attention.py
ADDED
@@ -0,0 +1,100 @@
import torch
import torch.nn as nn

from .transformer import PositionalEncoding


class Attention(nn.Module):
    def __init__(self, in_channels=512, max_length=25, n_feature=256):
        super().__init__()
        self.max_length = max_length

        self.f0_embedding = nn.Embedding(max_length, in_channels)
        self.w0 = nn.Linear(max_length, n_feature)
        self.wv = nn.Linear(in_channels, in_channels)
        self.we = nn.Linear(in_channels, max_length)

        self.active = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)

    def forward(self, enc_output):
        enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2)
        reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device)
        reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1)  # (S,) -> (B, S)
        reading_order_embed = self.f0_embedding(reading_order)  # b,25,512

        t = self.w0(reading_order_embed.permute(0, 2, 1))  # b,512,256
        t = self.active(t.permute(0, 2, 1) + self.wv(enc_output))  # b,256,512

        attn = self.we(t)  # b,256,25
        attn = self.softmax(attn.permute(0, 2, 1))  # b,25,256
        g_output = torch.bmm(attn, enc_output)  # b,25,512
        return g_output, attn.view(*attn.shape[:2], 8, 32)


def encoder_layer(in_c, out_c, k=3, s=2, p=1):
    return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p),
                         nn.BatchNorm2d(out_c),
                         nn.ReLU(True))


def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None):
    align_corners = None if mode == 'nearest' else True
    return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor,
                                     mode=mode, align_corners=align_corners),
                         nn.Conv2d(in_c, out_c, k, s, p),
                         nn.BatchNorm2d(out_c),
                         nn.ReLU(True))


class PositionAttention(nn.Module):
    def __init__(self, max_length, in_channels=512, num_channels=64,
                 h=8, w=32, mode='nearest', **kwargs):
        super().__init__()
        self.max_length = max_length
        self.k_encoder = nn.Sequential(
            encoder_layer(in_channels, num_channels, s=(1, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2))
        )
        self.k_decoder = nn.Sequential(
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, in_channels, size=(h, w), mode=mode)
        )

        self.pos_encoder = PositionalEncoding(in_channels, dropout=0., max_len=max_length)
        self.project = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        N, E, H, W = x.size()
        k, v = x, x  # (N, E, H, W)

        # calculate key vector
        features = []
        for i in range(0, len(self.k_encoder)):
            k = self.k_encoder[i](k)
            features.append(k)
        for i in range(0, len(self.k_decoder) - 1):
            k = self.k_decoder[i](k)
            k = k + features[len(self.k_decoder) - 2 - i]
        k = self.k_decoder[-1](k)

        # calculate query vector
        # TODO q=f(q,k)
        zeros = x.new_zeros((self.max_length, N, E))  # (T, N, E)
        q = self.pos_encoder(zeros)  # (T, N, E)
        q = q.permute(1, 0, 2)  # (N, T, E)
        q = self.project(q)  # (N, T, E)

        # calculate attention
        attn_scores = torch.bmm(q, k.flatten(2, 3))  # (N, T, (H*W))
        attn_scores = attn_scores / (E ** 0.5)
        attn_scores = torch.softmax(attn_scores, dim=-1)

        v = v.permute(0, 2, 3, 1).view(N, -1, E)  # (N, (H*W), E)
        attn_vecs = torch.bmm(attn_scores, v)  # (N, T, E)

        return attn_vecs, attn_scores.view(N, -1, H, W)
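A minimal shape sketch for PositionAttention, assuming the default 8x32 feature map produced by the backbone (the class above must be in scope):

import torch

attn = PositionAttention(max_length=26)      # 25 characters + 1 stop token
x = torch.randn(2, 512, 8, 32)               # (N, E, H, W) backbone features
attn_vecs, attn_scores = attn(x)
assert attn_vecs.shape == (2, 26, 512)       # one E-dim vector per output position
assert attn_scores.shape == (2, 26, 8, 32)   # per-position attention over the image grid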
strhub/models/abinet/backbone.py
ADDED
@@ -0,0 +1,24 @@
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder

from .resnet import resnet45
from .transformer import PositionalEncoding


class ResTranformer(nn.Module):
    def __init__(self, d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', backbone_ln=2):
        super().__init__()
        self.resnet = resnet45()
        self.pos_encoder = PositionalEncoding(d_model, max_len=8 * 32)
        encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                dim_feedforward=d_inner, dropout=dropout, activation=activation)
        self.transformer = TransformerEncoder(encoder_layer, backbone_ln)

    def forward(self, images):
        feature = self.resnet(images)
        n, c, h, w = feature.shape
        feature = feature.view(n, c, -1).permute(2, 0, 1)
        feature = self.pos_encoder(feature)
        feature = self.transformer(feature)
        feature = feature.permute(1, 2, 0).view(n, c, h, w)
        return feature
strhub/models/abinet/model.py
ADDED
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn


class Model(nn.Module):

    def __init__(self, dataset_max_length: int, null_label: int):
        super().__init__()
        self.max_length = dataset_max_length + 1  # additional stop token
        self.null_label = null_label

    def _get_length(self, logit, dim=-1):
        """Greedy decoder to obtain length from logit"""
        out = (logit.argmax(dim=-1) == self.null_label)
        abn = out.any(dim)
        out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
        out = out + 1  # additional end token
        out = torch.where(abn, out, out.new_tensor(logit.shape[1], device=out.device))
        return out

    @staticmethod
    def _get_padding_mask(length, max_length):
        length = length.unsqueeze(-1)
        grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
        return grid >= length

    @staticmethod
    def _get_location_mask(sz, device=None):
        mask = torch.eye(sz, device=device)
        mask = mask.float().masked_fill(mask == 1, float('-inf'))
        return mask
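To illustrate _get_length above: it greedily finds the first position predicted as the null (stop) label and returns that index plus one, falling back to the full sequence length when no null is predicted. A small sketch with made-up logits (assumes Model is in scope):

import torch

m = Model(dataset_max_length=4, null_label=0)
logits = torch.full((1, 5, 3), -1.0)  # (N, T, C)
logits[0, 0, 1] = 1.0                 # t=0 -> class 1
logits[0, 1, 2] = 1.0                 # t=1 -> class 2
logits[0, 2, 0] = 1.0                 # t=2 -> null: sequence ends here
print(m._get_length(logits))          # tensor([3]) = index of first null + 1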
strhub/models/abinet/model_abinet_iter.py
ADDED
@@ -0,0 +1,39 @@
import torch
from torch import nn

from .model_alignment import BaseAlignment
from .model_language import BCNLanguage
from .model_vision import BaseVision


class ABINetIterModel(nn.Module):
    def __init__(self, dataset_max_length, null_label, num_classes, iter_size=1,
                 d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
                 v_loss_weight=1., v_attention='position', v_attention_mode='nearest',
                 v_backbone='transformer', v_num_layers=2,
                 l_loss_weight=1., l_num_layers=4, l_detach=True, l_use_self_attn=False,
                 a_loss_weight=1.):
        super().__init__()
        self.iter_size = iter_size
        self.vision = BaseVision(dataset_max_length, null_label, num_classes, v_attention, v_attention_mode,
                                 v_loss_weight, d_model, nhead, d_inner, dropout, activation, v_backbone, v_num_layers)
        self.language = BCNLanguage(dataset_max_length, null_label, num_classes, d_model, nhead, d_inner, dropout,
                                    activation, l_num_layers, l_detach, l_use_self_attn, l_loss_weight)
        self.alignment = BaseAlignment(dataset_max_length, null_label, num_classes, d_model, a_loss_weight)

    def forward(self, images):
        v_res = self.vision(images)
        a_res = v_res
        all_l_res, all_a_res = [], []
        for _ in range(self.iter_size):
            tokens = torch.softmax(a_res['logits'], dim=-1)
            lengths = a_res['pt_lengths']
            lengths.clamp_(2, self.language.max_length)  # TODO: move to language model
            l_res = self.language(tokens, lengths)
            all_l_res.append(l_res)
            a_res = self.alignment(l_res['feature'], v_res['feature'])
            all_a_res.append(a_res)
        if self.training:
            return all_a_res, all_l_res, v_res
        else:
            return a_res, all_l_res[-1], v_res
strhub/models/abinet/model_alignment.py
ADDED
@@ -0,0 +1,28 @@
import torch
import torch.nn as nn

from .model import Model


class BaseAlignment(Model):
    def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, loss_weight=1.0):
        super().__init__(dataset_max_length, null_label)
        self.loss_weight = loss_weight
        self.w_att = nn.Linear(2 * d_model, d_model)
        self.cls = nn.Linear(d_model, num_classes)

    def forward(self, l_feature, v_feature):
        """
        Args:
            l_feature: (N, T, E) where T is length, N is batch size and E is dim of model
            v_feature: (N, T, E) shape the same as l_feature
        """
        f = torch.cat((l_feature, v_feature), dim=2)
        f_att = torch.sigmoid(self.w_att(f))
        output = f_att * v_feature + (1 - f_att) * l_feature

        logits = self.cls(output)  # (N, T, C)
        pt_lengths = self._get_length(logits)

        return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight': self.loss_weight,
                'name': 'alignment'}
strhub/models/abinet/model_language.py
ADDED
@@ -0,0 +1,50 @@
import torch.nn as nn
from torch.nn import TransformerDecoder

from .model import Model
from .transformer import PositionalEncoding, TransformerDecoderLayer


class BCNLanguage(Model):
    def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, nhead=8, d_inner=2048, dropout=0.1,
                 activation='relu', num_layers=4, detach=True, use_self_attn=False, loss_weight=1.0,
                 global_debug=False):
        super().__init__(dataset_max_length, null_label)
        self.detach = detach
        self.loss_weight = loss_weight
        self.proj = nn.Linear(num_classes, d_model, False)
        self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length)
        self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length)
        decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout,
                                                activation, self_attn=use_self_attn, debug=global_debug)
        self.model = TransformerDecoder(decoder_layer, num_layers)
        self.cls = nn.Linear(d_model, num_classes)

    def forward(self, tokens, lengths):
        """
        Args:
            tokens: (N, T, C) where T is length, N is batch size and C is classes number
            lengths: (N,)
        """
        if self.detach:
            tokens = tokens.detach()
        embed = self.proj(tokens)  # (N, T, E)
        embed = embed.permute(1, 0, 2)  # (T, N, E)
        embed = self.token_encoder(embed)  # (T, N, E)
        padding_mask = self._get_padding_mask(lengths, self.max_length)

        zeros = embed.new_zeros(*embed.shape)
        query = self.pos_encoder(zeros)
        location_mask = self._get_location_mask(self.max_length, tokens.device)
        output = self.model(query, embed,
                            tgt_key_padding_mask=padding_mask,
                            memory_mask=location_mask,
                            memory_key_padding_mask=padding_mask)  # (T, N, E)
        output = output.permute(1, 0, 2)  # (N, T, E)

        logits = self.cls(output)  # (N, T, C)
        pt_lengths = self._get_length(logits)

        res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths,
               'loss_weight': self.loss_weight, 'name': 'language'}
        return res
strhub/models/abinet/model_vision.py
ADDED
@@ -0,0 +1,45 @@
from torch import nn

from .attention import PositionAttention, Attention
from .backbone import ResTranformer
from .model import Model
from .resnet import resnet45


class BaseVision(Model):
    def __init__(self, dataset_max_length, null_label, num_classes,
                 attention='position', attention_mode='nearest', loss_weight=1.0,
                 d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
                 backbone='transformer', backbone_ln=2):
        super().__init__(dataset_max_length, null_label)
        self.loss_weight = loss_weight
        self.out_channels = d_model

        if backbone == 'transformer':
            self.backbone = ResTranformer(d_model, nhead, d_inner, dropout, activation, backbone_ln)
        else:
            self.backbone = resnet45()

        if attention == 'position':
            self.attention = PositionAttention(
                max_length=self.max_length,
                mode=attention_mode
            )
        elif attention == 'attention':
            self.attention = Attention(
                max_length=self.max_length,
                n_feature=8 * 32,
            )
        else:
            raise ValueError(f'invalid attention: {attention}')

        self.cls = nn.Linear(self.out_channels, num_classes)

    def forward(self, images):
        features = self.backbone(images)  # (N, E, H, W)
        attn_vecs, attn_scores = self.attention(features)  # (N, T, E), (N, T, H, W)
        logits = self.cls(attn_vecs)  # (N, T, C)
        pt_lengths = self._get_length(logits)

        return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
                'attn_scores': attn_scores, 'loss_weight': self.loss_weight, 'name': 'vision'}
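An end-to-end shape sketch for the vision model above, with the default transformer backbone; a 3x32x128 input maps to the 8x32 grid assumed by the attention module (BaseVision must be in scope):

import torch

vm = BaseVision(dataset_max_length=25, null_label=0, num_classes=37)
res = vm(torch.randn(1, 3, 32, 128))
assert res['logits'].shape == (1, 26, 37)         # (N, max_length, num_classes)
assert res['attn_scores'].shape == (1, 26, 8, 32)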
strhub/models/abinet/resnet.py
ADDED
@@ -0,0 +1,72 @@
import math
from typing import Optional, Callable

import torch.nn as nn
from torchvision.models import resnet


class BasicBlock(resnet.BasicBlock):

    def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None,
                 groups: int = 1, base_width: int = 64, dilation: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)
        self.conv1 = resnet.conv1x1(inplanes, planes)
        self.conv2 = resnet.conv3x3(planes, planes, stride)


class ResNet(nn.Module):

    def __init__(self, block, layers):
        super().__init__()
        self.inplanes = 32
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
        self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
        self.layer5 = self._make_layer(block, 512, layers[4], stride=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        return x


def resnet45():
    return ResNet(BasicBlock, [3, 4, 6, 6, 3])
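Only conv1 keeps stride 1 while layer1 and layer3 use stride 2, so resnet45 downsamples by 4 in each spatial dimension; a 32x128 crop thus yields the 8x32 feature map hard-coded in the attention and positional-encoding modules. A quick check (assumes resnet45 above is in scope):

import torch

net = resnet45()
y = net(torch.randn(1, 3, 32, 128))
assert y.shape == (1, 512, 8, 32)  # downsampled 4x by the two stride-2 stages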
strhub/models/abinet/system.py
ADDED
@@ -0,0 +1,172 @@
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
from typing import Any, Tuple, List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

from pytorch_lightning.utilities.types import STEP_OUTPUT
from timm.optim.optim_factory import param_groups_weight_decay

from strhub.models.base import CrossEntropySystem
from strhub.models.utils import init_weights
from .model_abinet_iter import ABINetIterModel as Model

log = logging.getLogger(__name__)


class ABINet(CrossEntropySystem):

    def __init__(self, charset_train: str, charset_test: str, max_label_length: int,
                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float,
                 iter_size: int, d_model: int, nhead: int, d_inner: int, dropout: float, activation: str,
                 v_loss_weight: float, v_attention: str, v_attention_mode: str, v_backbone: str, v_num_layers: int,
                 l_loss_weight: float, l_num_layers: int, l_detach: bool, l_use_self_attn: bool,
                 l_lr: float, a_loss_weight: float, lm_only: bool = False, **kwargs) -> None:
        super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
        self.scheduler = None
        self.save_hyperparameters()
        self.max_label_length = max_label_length
        self.num_classes = len(self.tokenizer) - 2  # We don't predict <bos> nor <pad>
        self.model = Model(max_label_length, self.eos_id, self.num_classes, iter_size, d_model, nhead, d_inner,
                           dropout, activation, v_loss_weight, v_attention, v_attention_mode, v_backbone, v_num_layers,
                           l_loss_weight, l_num_layers, l_detach, l_use_self_attn, a_loss_weight)
        self.model.apply(init_weights)
        # FIXME: doesn't support resumption from checkpoint yet
        self._reset_alignment = True
        self._reset_optimizers = True
        self.l_lr = l_lr
        self.lm_only = lm_only
        # Train LM only. Freeze other submodels.
        if lm_only:
            self.l_lr = lr  # for tuning
            self.model.vision.requires_grad_(False)
            self.model.alignment.requires_grad_(False)

    @property
    def _pretraining(self):
        # In the original work, VM was pretrained for 8 epochs while full model was trained for an additional 10 epochs.
        total_steps = self.trainer.estimated_stepping_batches * self.trainer.accumulate_grad_batches
        return self.global_step < (8 / (8 + 10)) * total_steps

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'model.language.proj.weight'}

    def _add_weight_decay(self, model: nn.Module, skip_list=()):
        if self.weight_decay:
            return param_groups_weight_decay(model, self.weight_decay, skip_list)
        else:
            return [{'params': model.parameters()}]

    def configure_optimizers(self):
        agb = self.trainer.accumulate_grad_batches
        # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP.
        lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.
        lr = lr_scale * self.lr
        l_lr = lr_scale * self.l_lr
        params = []
        params.extend(self._add_weight_decay(self.model.vision))
        params.extend(self._add_weight_decay(self.model.alignment))
        # We use a different learning rate for the LM.
        for p in self._add_weight_decay(self.model.language, ('proj.weight',)):
            p['lr'] = l_lr
            params.append(p)
        max_lr = [p.get('lr', lr) for p in params]
        optim = AdamW(params, lr)
        self.scheduler = OneCycleLR(optim, max_lr, self.trainer.estimated_stepping_batches,
                                    pct_start=self.warmup_pct, cycle_momentum=False)
        return {'optimizer': optim, 'lr_scheduler': {'scheduler': self.scheduler, 'interval': 'step'}}

    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
        max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
        logits = self.model.forward(images)[0]['logits']
        return logits[:, :max_length + 1]  # truncate

    def calc_loss(self, targets, *res_lists) -> Tensor:
        total_loss = 0
        for res_list in res_lists:
            loss = 0
            if isinstance(res_list, dict):
                res_list = [res_list]
            for res in res_list:
                logits = res['logits'].flatten(end_dim=1)
                loss += F.cross_entropy(logits, targets.flatten(), ignore_index=self.pad_id)
            loss /= len(res_list)
            self.log('loss_' + res_list[0]['name'], loss)
            total_loss += res_list[0]['loss_weight'] * loss
        return total_loss

    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
        if not self._pretraining and self._reset_optimizers:
            log.info('Pretraining ends. Updating base LRs.')
            self._reset_optimizers = False
            # Make base_lr the same for all groups
            base_lr = self.scheduler.base_lrs[0]  # base_lr of group 0 - VM
            self.scheduler.base_lrs = [base_lr] * len(self.scheduler.base_lrs)

    def _prepare_inputs_and_targets(self, labels):
        # Use dummy label to ensure sequence length is constant.
        dummy = ['0' * self.max_label_length]
        targets = self.tokenizer.encode(dummy + list(labels), self.device)[1:]
        targets = targets[:, 1:]  # remove <bos>. Unused here.
        # Inputs are padded with eos_id
        inputs = torch.where(targets == self.pad_id, self.eos_id, targets)
        inputs = F.one_hot(inputs, self.num_classes).float()
        lengths = torch.as_tensor(list(map(len, labels)), device=self.device) + 1  # +1 for eos
        return inputs, lengths, targets

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        images, labels = batch
        inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
        if self.lm_only:
            l_res = self.model.language(inputs, lengths)
            loss = self.calc_loss(targets, l_res)
        # Pretrain submodels independently first
        elif self._pretraining:
            # Vision
            v_res = self.model.vision(images)
            # Language
            l_res = self.model.language(inputs, lengths)
            # We also train the alignment model to 'satisfy' DDP requirements (all parameters should be used).
            # We'll reset its parameters prior to joint training.
            a_res = self.model.alignment(l_res['feature'].detach(), v_res['feature'].detach())
            loss = self.calc_loss(targets, v_res, l_res, a_res)
        else:
            # Reset alignment model's parameters once prior to full model training.
            if self._reset_alignment:
                log.info('Pretraining ends. Resetting alignment model.')
                self._reset_alignment = False
                self.model.alignment.apply(init_weights)
            all_a_res, all_l_res, v_res = self.model.forward(images)
            loss = self.calc_loss(targets, v_res, all_l_res, all_a_res)
        self.log('loss', loss)
        return loss

    def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
        if self.lm_only:
            inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
            l_res = self.model.language(inputs, lengths)
            loss = self.calc_loss(targets, l_res)
            loss_numel = (targets != self.pad_id).sum()
            return l_res['logits'], loss, loss_numel
        else:
            return super().forward_logits_loss(images, labels)
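A worked example of the LR scaling in configure_optimizers and the pretraining cutoff in _pretraining (all values illustrative):

import math

# Effective LR with accumulate_grad_batches=1, 2 GPUs and a per-GPU batch size of 384:
agb, num_devices, batch_size, base_lr = 1, 2, 384, 3.4e-4
lr = agb * math.sqrt(num_devices) * batch_size / 256. * base_lr
print(f'{lr:.2e}')   # ~7.21e-04

# Pretraining covers the first 8/(8+10) of all optimization steps, mirroring the
# original ABINet schedule of 8 pretraining + 10 joint-training epochs.
print(8 / (8 + 10))  # ~0.444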
strhub/models/abinet/transformer.py
ADDED
@@ -0,0 +1,143 @@
import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.modules.transformer import _get_activation_fn


class TransformerDecoderLayer(nn.Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", self_attn=True, siamese=False, debug=False):
        super().__init__()
        self.has_self_attn, self.siamese = self_attn, siamese
        self.debug = debug
        if self.has_self_attn:
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
            self.norm1 = nn.LayerNorm(d_model)
            self.dropout1 = nn.Dropout(dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        if self.siamese:
            self.multihead_attn2 = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None,
                memory2=None, memory_mask2=None, memory_key_padding_mask2=None):
        # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        if self.has_self_attn:
            tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                                        key_padding_mask=tgt_key_padding_mask)
            tgt = tgt + self.dropout1(tgt2)
            tgt = self.norm1(tgt)
            if self.debug: self.attn = attn
        tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                          key_padding_mask=memory_key_padding_mask)
        if self.debug: self.attn2 = attn2

        if self.siamese:
            tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2,
                                               key_padding_mask=memory_key_padding_mask2)
            tgt = tgt + self.dropout2(tgt3)
            if self.debug: self.attn3 = attn3

        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        return tgt


class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens
    in the sequence. The positional encodings have the same dimension as
    the embeddings, so that the two can be summed. Here, we use sine and cosine
    functions of different frequencies.
    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i/d_{model}})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i/d_{model}})
    where pos is the word position and i is the embed idx.
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
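Minimal usage sketch for PositionalEncoding above: the module adds a fixed sinusoidal table to a sequence-first input and leaves the shape unchanged (the class must be in scope):

import torch

pos = PositionalEncoding(d_model=512)
x = torch.zeros(20, 4, 512)   # (sequence length, batch size, embed dim)
y = pos(x)
assert y.shape == (20, 4, 512)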
strhub/models/base.py
ADDED
@@ -0,0 +1,202 @@
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple, List

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from nltk import edit_distance
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from timm.optim import create_optimizer_v2
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import OneCycleLR

from strhub.data.utils import CharsetAdapter, CTCTokenizer, Tokenizer, BaseTokenizer


@dataclass
class BatchResult:
    num_samples: int
    correct: int
    ned: float
    confidence: float
    label_length: int
    loss: Tensor
    loss_numel: int


class BaseSystem(pl.LightningModule, ABC):

    def __init__(self, tokenizer: BaseTokenizer, charset_test: str,
                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.charset_adapter = CharsetAdapter(charset_test)
        self.batch_size = batch_size
        self.lr = lr
        self.warmup_pct = warmup_pct
        self.weight_decay = weight_decay

    @abstractmethod
    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
        """Inference

        Args:
            images: Batch of images. Shape: N, Ch, H, W
            max_length: Max sequence length of the output. If None, will use default.

        Returns:
            logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
        """
        raise NotImplementedError

    @abstractmethod
    def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
        """Like forward(), but also computes the loss (calls forward() internally).

        Args:
            images: Batch of images. Shape: N, Ch, H, W
            labels: Text labels of the images

        Returns:
            logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
            loss: mean loss for the batch
            loss_numel: number of elements the loss was calculated from
        """
        raise NotImplementedError

    def configure_optimizers(self):
        agb = self.trainer.accumulate_grad_batches
        # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP.
        lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.
        lr = lr_scale * self.lr
        optim = create_optimizer_v2(self, 'adamw', lr, self.weight_decay)
        sched = OneCycleLR(optim, lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct,
                           cycle_momentum=False)
        return {'optimizer': optim, 'lr_scheduler': {'scheduler': sched, 'interval': 'step'}}

    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
        optimizer.zero_grad(set_to_none=True)

    def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]:
        images, labels = batch

        correct = 0
        total = 0
        ned = 0
        confidence = 0
        label_length = 0
        if validation:
            logits, loss, loss_numel = self.forward_logits_loss(images, labels)
        else:
            # At test-time, we shouldn't specify a max_label_length because the test-time charset used
            # might be different from the train-time charset. max_label_length in eval_logits_loss() is computed
            # based on the transformed label, which could be wrong if the actual gt label contains characters existing
            # in the train-time charset but not in the test-time charset. For example, "aishahaleyes.blogspot.com"
            # is exactly 25 characters, but if processed by CharsetAdapter for the 36-char set, it becomes 23 characters
            # long only, which sets max_label_length = 23. This will cause the model prediction to be truncated.
            logits = self.forward(images)
            loss = loss_numel = None  # Only used for validation; not needed at test-time.

        probs = logits.softmax(-1)
        preds, probs = self.tokenizer.decode(probs)
        for pred, prob, gt in zip(preds, probs, labels):
            confidence += prob.prod().item()
            pred = self.charset_adapter(pred)
            # Follow ICDAR 2019 definition of N.E.D.
            ned += edit_distance(pred, gt) / max(len(pred), len(gt))
            if pred == gt:
                correct += 1
            total += 1
            label_length += len(pred)
        return dict(output=BatchResult(total, correct, ned, confidence, label_length, loss, loss_numel))

    @staticmethod
    def _aggregate_results(outputs: EPOCH_OUTPUT) -> Tuple[float, float, float]:
        if not outputs:
            return 0., 0., 0.
        total_loss = 0
        total_loss_numel = 0
        total_n_correct = 0
        total_norm_ED = 0
        total_size = 0
        for result in outputs:
            result = result['output']
            total_loss += result.loss_numel * result.loss
            total_loss_numel += result.loss_numel
            total_n_correct += result.correct
            total_norm_ED += result.ned
            total_size += result.num_samples
        acc = total_n_correct / total_size
        ned = (1 - total_norm_ED / total_size)
        loss = total_loss / total_loss_numel
        return acc, ned, loss

    def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
        return self._eval_step(batch, True)

    def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        acc, ned, loss = self._aggregate_results(outputs)
        self.log('val_accuracy', 100 * acc, sync_dist=True)
        self.log('val_NED', 100 * ned, sync_dist=True)
        self.log('val_loss', loss, sync_dist=True)
        self.log('hp_metric', acc, sync_dist=True)

    def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
        return self._eval_step(batch, False)


class CrossEntropySystem(BaseSystem):

    def __init__(self, charset_train: str, charset_test: str,
                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
        tokenizer = Tokenizer(charset_train)
        super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
        self.bos_id = tokenizer.bos_id
        self.eos_id = tokenizer.eos_id
        self.pad_id = tokenizer.pad_id

    def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
        targets = self.tokenizer.encode(labels, self.device)
        targets = targets[:, 1:]  # Discard <bos>
        max_len = targets.shape[1] - 1  # exclude <eos> from count
        logits = self.forward(images, max_len)
        loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id)
        loss_numel = (targets != self.pad_id).sum()
        return logits, loss, loss_numel


class CTCSystem(BaseSystem):

    def __init__(self, charset_train: str, charset_test: str,
                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
        tokenizer = CTCTokenizer(charset_train)
        super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
        self.blank_id = tokenizer.blank_id

    def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
        targets = self.tokenizer.encode(labels, self.device)
        logits = self.forward(images)
        log_probs = logits.log_softmax(-1).transpose(0, 1)  # swap batch and seq. dims
        T, N, _ = log_probs.shape
        input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long, device=self.device)
        target_lengths = torch.as_tensor(list(map(len, labels)), dtype=torch.long, device=self.device)
        loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True)
        return logits, loss, N
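The per-sample accumulation in _eval_step follows the ICDAR 2019 normalized edit distance, ED(pred, gt) / max(len(pred), len(gt)); validation_epoch_end then reports 100 * (1 - mean). A one-line example:

from nltk import edit_distance

pred, gt = 'availabe', 'available'
print(edit_distance(pred, gt) / max(len(pred), len(gt)))  # 1 edit over 9 chars ~ 0.111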
strhub/models/crnn/LICENSE
ADDED
@@ -0,0 +1,21 @@
The MIT License (MIT)

Copyright (c) 2017 Jieru Mei <meijieru@gmail.com>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
strhub/models/crnn/__init__.py
ADDED
@@ -0,0 +1,13 @@
r"""
Shi, Baoguang, Xiang Bai, and Cong Yao.
"An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition."
IEEE Transactions on Pattern Analysis and Machine Intelligence 39, no. 11 (2016): 2298-2304.

https://arxiv.org/abs/1507.05717

All source files, except `system.py`, are based on the implementation listed below,
and hence are released under the license of the original.

Source: https://github.com/meijieru/crnn.pytorch
License: MIT License (see included LICENSE file)
"""
|
strhub/models/crnn/model.py
ADDED
@@ -0,0 +1,62 @@
import torch.nn as nn

from strhub.models.modules import BidirectionalLSTM


class CRNN(nn.Module):

    def __init__(self, img_h, nc, nclass, nh, leaky_relu=False):
        super().__init__()
        assert img_h % 16 == 0, 'img_h has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i], bias=not batchNormalization))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leaky_relu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, 'the height of conv must be 1'
        conv = conv.squeeze(2)
        conv = conv.transpose(1, 2)  # [b, w, c]

        # rnn features
        output = self.rnn(conv)

        return output
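The pooling comments above trace how the CNN collapses the image height to 1 while keeping width as the sequence axis for the LSTMs. A quick shape check under assumed sizes (32x128 RGB input, 26 classes, both arbitrary):

import torch

model = CRNN(img_h=32, nc=3, nclass=26, nh=256)
x = torch.randn(4, 3, 32, 128)   # batch of 4 crops (hypothetical size)
y = model(x)
print(y.shape)                   # (4, W', 26): one class distribution per horizontal step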
strhub/models/crnn/system.py
ADDED
@@ -0,0 +1,43 @@
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence, Optional

from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import Tensor

from strhub.models.base import CTCSystem
from strhub.models.utils import init_weights
from .model import CRNN as Model


class CRNN(CTCSystem):

    def __init__(self, charset_train: str, charset_test: str, max_label_length: int,
                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float,
                 img_size: Sequence[int], hidden_size: int, leaky_relu: bool, **kwargs) -> None:
        super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
        self.save_hyperparameters()
        self.model = Model(img_size[0], 3, len(self.tokenizer), hidden_size, leaky_relu)
        self.model.apply(init_weights)

    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
        return self.model.forward(images)

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        images, labels = batch
        loss = self.forward_logits_loss(images, labels)[1]
        self.log('loss', loss)
        return loss
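A hedged construction sketch for the Lightning wrapper above; every hyperparameter value below is a placeholder rather than the shipped training config, and it assumes the strhub package is importable:

import torch

charset = '0123456789abcdefghijklmnopqrstuvwxyz'
model = CRNN(charset_train=charset, charset_test=charset, max_label_length=25,
             batch_size=384, lr=5e-4, warmup_pct=0.075, weight_decay=0.0,
             img_size=(32, 128), hidden_size=256, leaky_relu=False)
logits = model(torch.randn(1, 3, 32, 128))  # per-step scores over the tokenizer's charset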
strhub/models/modules.py
ADDED
@@ -0,0 +1,20 @@
r"""Shared modules used by CRNN and TRBA"""
from torch import nn


class BidirectionalLSTM(nn.Module):
    """Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py"""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input):
        """
        input : visual feature [batch_size x T x input_size], T = num_steps.
        output : contextual feature [batch_size x T x output_size]
        """
        recurrent, _ = self.rnn(input)  # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
        output = self.linear(recurrent)  # batch_size x T x output_size
        return output
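A shape sketch for BidirectionalLSTM (dummy sizes): the bidirectional LSTM doubles the feature dimension to 2*hidden_size, and the final Linear projects it back to output_size.

import torch

rnn = BidirectionalLSTM(input_size=512, hidden_size=256, output_size=26)
feats = torch.randn(4, 16, 512)   # batch_size x T x input_size
out = rnn(feats)
print(out.shape)                  # torch.Size([4, 16, 26])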
strhub/models/parseq/__init__.py
ADDED
File without changes
strhub/models/parseq/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (148 Bytes)
strhub/models/parseq/__pycache__/modules.cpython-37.pyc
ADDED
Binary file (4.96 kB)
strhub/models/parseq/__pycache__/system.cpython-37.pyc
ADDED
Binary file (7.62 kB)
strhub/models/parseq/modules.py
ADDED
@@ -0,0 +1,126 @@
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional

import torch
from torch import nn as nn, Tensor
from torch.nn import functional as F
from torch.nn.modules import transformer

from timm.models.vision_transformer import VisionTransformer, PatchEmbed


class DecoderLayer(nn.Module):
    """A Transformer decoder layer supporting two-stream attention (XLNet).
    This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu',
                 layer_norm_eps=1e-5):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = transformer._get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.gelu
        super().__setstate__(state)

    def forward_stream(self, tgt: Tensor, tgt_norm: Tensor, tgt_kv: Tensor, memory: Tensor, tgt_mask: Optional[Tensor],
                       tgt_key_padding_mask: Optional[Tensor]):
        """Forward pass for a single stream (i.e. content or query).
        tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
        Both tgt_kv and memory are expected to be LayerNorm'd too.
        memory is LayerNorm'd by ViT.
        """
        tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask,
                                          key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(tgt2)

        tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
        tgt = tgt + self.dropout2(tgt2)

        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt, sa_weights, ca_weights

    def forward(self, query, content, memory, query_mask: Optional[Tensor] = None, content_mask: Optional[Tensor] = None,
                content_key_padding_mask: Optional[Tensor] = None, update_content: bool = True):
        query_norm = self.norm_q(query)
        content_norm = self.norm_c(content)
        query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0]
        if update_content:
            content = self.forward_stream(content, content_norm, content_norm, memory, content_mask,
                                          content_key_padding_mask)[0]
        return query, content


class Decoder(nn.Module):
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm):
        super().__init__()
        self.layers = transformer._get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, query, content, memory, query_mask: Optional[Tensor] = None, content_mask: Optional[Tensor] = None,
                content_key_padding_mask: Optional[Tensor] = None):
        for i, mod in enumerate(self.layers):
            last = i == len(self.layers) - 1
            query, content = mod(query, content, memory, query_mask, content_mask, content_key_padding_mask,
                                 update_content=not last)
        query = self.norm(query)
        return query


class Encoder(VisionTransformer):

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
                 qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed):
        super().__init__(img_size, patch_size, in_chans, embed_dim=embed_dim, depth=depth, num_heads=num_heads,
                         mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                         drop_path_rate=drop_path_rate, embed_layer=embed_layer,
                         num_classes=0, global_pool='', class_token=False)  # these disable the classifier head

    def forward(self, x):
        # Return all tokens
        return self.forward_features(x)


class TokenEmbedding(nn.Module):

    def __init__(self, charset_size: int, embed_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(charset_size, embed_dim)
        self.embed_dim = embed_dim

    def forward(self, tokens: torch.Tensor):
        return math.sqrt(self.embed_dim) * self.embedding(tokens)
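To make the two-stream wiring concrete: the query stream attends to the content stream and the ViT memory, and Decoder skips the content update on the last layer because only the query output is consumed. (TokenEmbedding's sqrt(embed_dim) scaling is the standard Transformer embedding convention.) A minimal forward pass with dummy tensors; all dimensions below are illustrative, not the PARSeq defaults:

import torch
from torch import nn

d_model, nhead = 384, 6
layer = DecoderLayer(d_model, nhead)
decoder = Decoder(layer, num_layers=1, norm=nn.LayerNorm(d_model))

B, L, S = 2, 7, 128                    # batch, target length, image tokens (hypothetical)
query = torch.randn(B, L, d_model)     # positional queries
content = torch.randn(B, L, d_model)   # context/token embeddings
memory = torch.randn(B, S, d_model)    # ViT features (already LayerNorm'd)

out = decoder(query, content, memory)  # (B, L, d_model), ready for the output head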