diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba139f28fcd686f0b24e7487ae72937c4cb7f095
--- /dev/null
+++ b/app.py
@@ -0,0 +1,39 @@
+import gradio as gr
+import shutil
+
+import urllib.request
+import sys
+import os
+import zipfile
+
+sys.path.append(".")
+
+from model import prediction
+
+gr.close_all()
+
+# Download and extract the archive containing the weights and the sample images
+url = "https://storage.googleapis.com/models-gradio/products/products.zip"
+
+urllib.request.urlretrieve(url, "products.zip")
+
+
+with zipfile.ZipFile("products.zip", 'r') as zip_ref:
+    zip_ref.extractall()
+
+def predict(img):
+    name_image = img.split("/")[-1]
+
+    prediction_img, text = prediction(img)
+
+    return str(text), prediction_img
+
+
+sample_images = ["dataset/" + name for name in os.listdir("dataset")]
+
+
+gr.Interface(fn=predict,
+             inputs=[gr.Image(label="image à tester", type="filepath")],
+             outputs=[gr.Textbox(label="analyse"), gr.Image(label="résultat")],
+             css="footer {visibility: hidden} body, .gradio-container {background-color: white}",
+             examples=sample_images).launch(server_name="0.0.0.0", share=False)
\ No newline at end of file
diff --git a/model.py b/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf4c526f6f9452eac0cc2b6ecf44c0eebd2373f2
--- /dev/null
+++ b/model.py
@@ -0,0 +1,52 @@
+
+from strhub.data.module import SceneTextDataModule
+from strhub.models.utils import load_from_checkpoint
+from post import filter_mask
+import segmentation_models_pytorch as smp
+import albumentations as albu
+from torchvision import transforms
+from PIL import Image
+import torch
+import cv2
+
+model_recog = load_from_checkpoint("weights/parseq/last.ckpt").eval().to("cpu")
+img_transform = SceneTextDataModule.get_transform(model_recog.hparams.img_size)
+
+model = torch.load('weights/best_model.pth').to("cpu")
+model.eval()
+model.float()
+
+SHAPE_X = 384
+SHAPE_Y = 384
+
+
+def prediction(image_path):
+    image = cv2.imread(image_path)
+    image_original = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    preprocessing_fn = smp.encoders.get_preprocessing_fn('resnet50')
+    transform = albu.Compose([
+        albu.Lambda(image=preprocessing_fn), albu.Resize(SHAPE_X, SHAPE_Y)
+    ])
+
+    image_result = transform(image=image_original)["image"]
+    transform = transforms.ToTensor()
+    tensor = transform(image_result)
+    tensor = torch.unsqueeze(tensor, 0)
+    output = model.predict(tensor.float())
+
+    result, img_vis = filter_mask(output, image_original)
+
+    image = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
+    im_pil = Image.fromarray(image)
+    image = img_transform(im_pil).unsqueeze(0).to("cpu")
+
+    p = model_recog(image).softmax(-1)
+    pred, p = model_recog.tokenizer.decode(p)
+    print(f'{image_path}: {pred[0]}')
+
+
+    return img_vis, pred[0]
+
+
+
+
\ No newline at end of file
diff --git a/post.py b/post.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e86ff6a31a00a61ce599140c7165d91622d2a78
--- /dev/null
+++ b/post.py
@@ -0,0 +1,42 @@
+import cv2
+import numpy as np
+
+def filter_mask(output, image):
+    image_h, image_w = image.shape[:2]
+
+    predict_mask = (output.squeeze().cpu().numpy().round())
+    predict_mask = predict_mask.astype('uint8')*255
+    predict_mask = cv2.resize(predict_mask, (image_w, image_h))
+
+    ret, thresh = cv2.threshold(predict_mask, 127, 255, 0)
+    contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+    points = contours[0]
+
+    rect = cv2.minAreaRect(points)
+    box = cv2.boxPoints(rect)
+    box = 
np.int0(box) + + img_vis = cv2.drawContours(image.copy(),[box],0,(255,0,0), 6) + + (cX, cY), (w, h), angle = rect + + if w max_label_len: + continue + label = charset_adapter(label) + # We filter out samples which don't contain any supported characters + if not label: + continue + # Filter images that are too small. + if min_image_dim > 0: + img_key = f'image-{index:09d}'.encode() + buf = io.BytesIO(txn.get(img_key)) + w, h = Image.open(buf).size + if w < self.min_image_dim or h < self.min_image_dim: + continue + self.labels.append(label) + self.filtered_index_list.append(index) + return len(self.labels) + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + if self.unlabelled: + label = index + else: + label = self.labels[index] + index = self.filtered_index_list[index] + + img_key = f'image-{index:09d}'.encode() + with self.env.begin() as txn: + imgbuf = txn.get(img_key) + buf = io.BytesIO(imgbuf) + img = Image.open(buf).convert('RGB') + + if self.transform is not None: + img = self.transform(img) + + return img, label diff --git a/strhub/data/.ipynb_checkpoints/module-checkpoint.py b/strhub/data/.ipynb_checkpoints/module-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..8d8b2d3d63eaf0b6839f67cc4dca7c28898a658a --- /dev/null +++ b/strhub/data/.ipynb_checkpoints/module-checkpoint.py @@ -0,0 +1,107 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pathlib import PurePath +from typing import Optional, Callable, Sequence, Tuple + +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from torchvision import transforms as T + +from .dataset import build_tree_dataset, LmdbDataset + + +class SceneTextDataModule(pl.LightningDataModule): + TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80') + TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80') + TEST_NEW = ('ArT', 'COCOv1.4', 'Uber') + TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW)) + + def __init__(self, root_dir: str, train_dir: str, img_size: Sequence[int], max_label_length: int, + charset_train: str, charset_test: str, batch_size: int, num_workers: int, augment: bool, + remove_whitespace: bool = True, normalize_unicode: bool = True, + min_image_dim: int = 0, rotation: int = 0, collate_fn: Optional[Callable] = None): + super().__init__() + self.root_dir = root_dir + self.train_dir = train_dir + self.img_size = tuple(img_size) + self.max_label_length = max_label_length + self.charset_train = charset_train + self.charset_test = charset_test + self.batch_size = batch_size + self.num_workers = num_workers + self.augment = augment + self.remove_whitespace = remove_whitespace + self.normalize_unicode = normalize_unicode + self.min_image_dim = min_image_dim + self.rotation = rotation + self.collate_fn = collate_fn + self._train_dataset = None + self._val_dataset = None + + @staticmethod + def get_transform(img_size: Tuple[int], augment: bool = False, rotation: int = 0): + transforms = [] + if augment: + from .augment import rand_augment_transform + transforms.append(rand_augment_transform()) + if rotation: + transforms.append(lambda img: img.rotate(rotation, expand=True)) + transforms.extend([ + T.Resize(img_size, T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(0.5, 0.5) + ]) + return T.Compose(transforms) + + @property + def train_dataset(self): + if self._train_dataset is None: + transform = self.get_transform(self.img_size, self.augment) + root = PurePath(self.root_dir, 'train', self.train_dir) + self._train_dataset = build_tree_dataset(root, self.charset_train, self.max_label_length, + self.min_image_dim, self.remove_whitespace, self.normalize_unicode, + transform=transform) + return self._train_dataset + + @property + def val_dataset(self): + if self._val_dataset is None: + transform = self.get_transform(self.img_size) + root = PurePath(self.root_dir, 'val') + self._val_dataset = build_tree_dataset(root, self.charset_test, self.max_label_length, + self.min_image_dim, self.remove_whitespace, self.normalize_unicode, + transform=transform) + return self._val_dataset + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, + num_workers=self.num_workers, persistent_workers=self.num_workers > 0, + pin_memory=True, collate_fn=self.collate_fn) + + def val_dataloader(self): + return DataLoader(self.val_dataset, batch_size=self.batch_size, + num_workers=self.num_workers, persistent_workers=self.num_workers > 0, + pin_memory=True, collate_fn=self.collate_fn) + + def test_dataloaders(self, subset): + transform = self.get_transform(self.img_size, rotation=self.rotation) + root = PurePath(self.root_dir, 'test') + datasets = {s: LmdbDataset(str(root / s), self.charset_test, self.max_label_length, + self.min_image_dim, self.remove_whitespace, self.normalize_unicode, + transform=transform) for s in subset} + return {k: DataLoader(v, 
batch_size=self.batch_size, num_workers=self.num_workers, + pin_memory=True, collate_fn=self.collate_fn) + for k, v in datasets.items()} diff --git a/strhub/data/__init__.py b/strhub/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/strhub/data/__pycache__/__init__.cpython-37.pyc b/strhub/data/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..736e5eb047fe30ff9adae189e38e4baad97fbef3 Binary files /dev/null and b/strhub/data/__pycache__/__init__.cpython-37.pyc differ diff --git a/strhub/data/__pycache__/aa_overrides.cpython-37.pyc b/strhub/data/__pycache__/aa_overrides.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8de20d3ce1018c4c4f7134bf6e851bdafc5e10fc Binary files /dev/null and b/strhub/data/__pycache__/aa_overrides.cpython-37.pyc differ diff --git a/strhub/data/__pycache__/augment.cpython-37.pyc b/strhub/data/__pycache__/augment.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62c1b1dc542a6e677e878322cb4af62d54278051 Binary files /dev/null and b/strhub/data/__pycache__/augment.cpython-37.pyc differ diff --git a/strhub/data/__pycache__/dataset.cpython-37.pyc b/strhub/data/__pycache__/dataset.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..094036b1636f79e5a16bb57d8a29fb7c7e64300d Binary files /dev/null and b/strhub/data/__pycache__/dataset.cpython-37.pyc differ diff --git a/strhub/data/__pycache__/module.cpython-37.pyc b/strhub/data/__pycache__/module.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..326d84a6b91fec3dace386003fc8f9c6cad49a3e Binary files /dev/null and b/strhub/data/__pycache__/module.cpython-37.pyc differ diff --git a/strhub/data/__pycache__/utils.cpython-37.pyc b/strhub/data/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d585e9fd851d199cc6cf459daad3e1ba797ec1d Binary files /dev/null and b/strhub/data/__pycache__/utils.cpython-37.pyc differ diff --git a/strhub/data/aa_overrides.py b/strhub/data/aa_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..7dcba717676180d61bee82c0404b5d3b3a63d339 --- /dev/null +++ b/strhub/data/aa_overrides.py @@ -0,0 +1,46 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Extends default ops to accept optional parameters.""" +from functools import partial + +from timm.data.auto_augment import _LEVEL_DENOM, _randomly_negate, LEVEL_TO_ARG, NAME_TO_OP, rotate + + +def rotate_expand(img, degrees, **kwargs): + """Rotate operation with expand=True to avoid cutting off the characters""" + kwargs['expand'] = True + return rotate(img, degrees, **kwargs) + + +def _level_to_arg(level, hparams, key, default): + magnitude = hparams.get(key, default) + level = (level / _LEVEL_DENOM) * magnitude + level = _randomly_negate(level) + return level, + + +def apply(): + # Overrides + NAME_TO_OP.update({ + 'Rotate': rotate_expand + }) + LEVEL_TO_ARG.update({ + 'Rotate': partial(_level_to_arg, key='rotate_deg', default=30.), + 'ShearX': partial(_level_to_arg, key='shear_x_pct', default=0.3), + 'ShearY': partial(_level_to_arg, key='shear_y_pct', default=0.3), + 'TranslateXRel': partial(_level_to_arg, key='translate_x_pct', default=0.45), + 'TranslateYRel': partial(_level_to_arg, key='translate_y_pct', default=0.45), + }) diff --git a/strhub/data/augment.py b/strhub/data/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c1fb5a5ee5cc307f056d9ea0c78e91514078e6 --- /dev/null +++ b/strhub/data/augment.py @@ -0,0 +1,111 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial + +import imgaug.augmenters as iaa +import numpy as np +from PIL import ImageFilter, Image +from timm.data import auto_augment + +from strhub.data import aa_overrides + +aa_overrides.apply() + +_OP_CACHE = {} + + +def _get_op(key, factory): + try: + op = _OP_CACHE[key] + except KeyError: + op = factory() + _OP_CACHE[key] = op + return op + + +def _get_param(level, img, max_dim_factor, min_level=1): + max_level = max(min_level, max_dim_factor * max(img.size)) + return round(min(level, max_level)) + + +def gaussian_blur(img, radius, **__): + radius = _get_param(radius, img, 0.02) + key = 'gaussian_blur_' + str(radius) + op = _get_op(key, lambda: ImageFilter.GaussianBlur(radius)) + return img.filter(op) + + +def motion_blur(img, k, **__): + k = _get_param(k, img, 0.08, 3) | 1 # bin to odd values + key = 'motion_blur_' + str(k) + op = _get_op(key, lambda: iaa.MotionBlur(k)) + return Image.fromarray(op(image=np.asarray(img))) + + +def gaussian_noise(img, scale, **_): + scale = _get_param(scale, img, 0.25) | 1 # bin to odd values + key = 'gaussian_noise_' + str(scale) + op = _get_op(key, lambda: iaa.AdditiveGaussianNoise(scale=scale)) + return Image.fromarray(op(image=np.asarray(img))) + + +def poisson_noise(img, lam, **_): + lam = _get_param(lam, img, 0.2) | 1 # bin to odd values + key = 'poisson_noise_' + str(lam) + op = _get_op(key, lambda: iaa.AdditivePoissonNoise(lam)) + return Image.fromarray(op(image=np.asarray(img))) + + +def _level_to_arg(level, _hparams, max): + level = max * level / auto_augment._LEVEL_DENOM + return level, + + +_RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy() +_RAND_TRANSFORMS.remove('SharpnessIncreasing') # remove, interferes with *blur ops +_RAND_TRANSFORMS.extend([ + 'GaussianBlur', + # 'MotionBlur', + # 'GaussianNoise', + 'PoissonNoise' +]) +auto_augment.LEVEL_TO_ARG.update({ + 'GaussianBlur': partial(_level_to_arg, max=4), + 'MotionBlur': partial(_level_to_arg, max=20), + 'GaussianNoise': partial(_level_to_arg, max=0.1 * 255), + 'PoissonNoise': partial(_level_to_arg, max=40) +}) +auto_augment.NAME_TO_OP.update({ + 'GaussianBlur': gaussian_blur, + 'MotionBlur': motion_blur, + 'GaussianNoise': gaussian_noise, + 'PoissonNoise': poisson_noise +}) + + +def rand_augment_transform(magnitude=5, num_layers=3): + # These are tuned for magnitude=5, which means that effective magnitudes are half of these values. + hparams = { + 'rotate_deg': 30, + 'shear_x_pct': 0.9, + 'shear_y_pct': 0.2, + 'translate_x_pct': 0.10, + 'translate_y_pct': 0.30 + } + ra_ops = auto_augment.rand_augment_ops(magnitude, hparams, transforms=_RAND_TRANSFORMS) + # Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice) + choice_weights = [1. / len(ra_ops) for _ in range(len(ra_ops))] + return auto_augment.RandAugment(ra_ops, num_layers, choice_weights) diff --git a/strhub/data/dataset.py b/strhub/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e1da774dad124346f1fecbdab4ab7885a2cd8cdd --- /dev/null +++ b/strhub/data/dataset.py @@ -0,0 +1,137 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import glob +import io +import logging +import unicodedata +from pathlib import Path, PurePath +from typing import Callable, Optional, Union + +import lmdb +from PIL import Image +from torch.utils.data import Dataset, ConcatDataset + +from strhub.data.utils import CharsetAdapter + +log = logging.getLogger(__name__) + + +def build_tree_dataset(root: Union[PurePath, str], *args, **kwargs): + try: + kwargs.pop('root') # prevent 'root' from being passed via kwargs + except KeyError: + pass + root = Path(root).absolute() + log.info(f'dataset root:\t{root}') + datasets = [] + for mdb in glob.glob(str(root / '**/data.mdb'), recursive=True): + mdb = Path(mdb) + ds_name = str(mdb.parent.relative_to(root)) + ds_root = str(mdb.parent.absolute()) + dataset = LmdbDataset(ds_root, *args, **kwargs) + log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}') + datasets.append(dataset) + return ConcatDataset(datasets) + + +class LmdbDataset(Dataset): + """Dataset interface to an LMDB database. + + It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned + as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset. + Labels are transformed according to the charset. + """ + + def __init__(self, root: str, charset: str, max_label_len: int, min_image_dim: int = 0, + remove_whitespace: bool = True, normalize_unicode: bool = True, + unlabelled: bool = False, transform: Optional[Callable] = None): + self._env = None + self.root = root + self.unlabelled = unlabelled + self.transform = transform + self.labels = [] + self.filtered_index_list = [] + self.num_samples = self._preprocess_labels(charset, remove_whitespace, normalize_unicode, + max_label_len, min_image_dim) + + def __del__(self): + if self._env is not None: + self._env.close() + self._env = None + + def _create_env(self): + return lmdb.open(self.root, max_readers=1, readonly=True, create=False, + readahead=False, meminit=False, lock=False) + + @property + def env(self): + if self._env is None: + self._env = self._create_env() + return self._env + + def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim): + charset_adapter = CharsetAdapter(charset) + with self._create_env() as env, env.begin() as txn: + num_samples = int(txn.get('num-samples'.encode())) + if self.unlabelled: + return num_samples + for index in range(num_samples): + index += 1 # lmdb starts with 1 + label_key = f'label-{index:09d}'.encode() + label = txn.get(label_key).decode() + # Normally, whitespace is removed from the labels. + if remove_whitespace: + label = ''.join(label.split()) + # Normalize unicode composites (if any) and convert to compatible ASCII characters + if normalize_unicode: + label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode() + # Filter by length before removing unsupported characters. The original label might be too long. 
+ if len(label) > max_label_len: + continue + label = charset_adapter(label) + # We filter out samples which don't contain any supported characters + if not label: + continue + # Filter images that are too small. + if min_image_dim > 0: + img_key = f'image-{index:09d}'.encode() + buf = io.BytesIO(txn.get(img_key)) + w, h = Image.open(buf).size + if w < self.min_image_dim or h < self.min_image_dim: + continue + self.labels.append(label) + self.filtered_index_list.append(index) + return len(self.labels) + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + if self.unlabelled: + label = index + else: + label = self.labels[index] + index = self.filtered_index_list[index] + + img_key = f'image-{index:09d}'.encode() + with self.env.begin() as txn: + imgbuf = txn.get(img_key) + buf = io.BytesIO(imgbuf) + img = Image.open(buf).convert('RGB') + + if self.transform is not None: + img = self.transform(img) + + return img, label diff --git a/strhub/data/module.py b/strhub/data/module.py new file mode 100644 index 0000000000000000000000000000000000000000..8d8b2d3d63eaf0b6839f67cc4dca7c28898a658a --- /dev/null +++ b/strhub/data/module.py @@ -0,0 +1,107 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pathlib import PurePath +from typing import Optional, Callable, Sequence, Tuple + +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from torchvision import transforms as T + +from .dataset import build_tree_dataset, LmdbDataset + + +class SceneTextDataModule(pl.LightningDataModule): + TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80') + TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80') + TEST_NEW = ('ArT', 'COCOv1.4', 'Uber') + TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW)) + + def __init__(self, root_dir: str, train_dir: str, img_size: Sequence[int], max_label_length: int, + charset_train: str, charset_test: str, batch_size: int, num_workers: int, augment: bool, + remove_whitespace: bool = True, normalize_unicode: bool = True, + min_image_dim: int = 0, rotation: int = 0, collate_fn: Optional[Callable] = None): + super().__init__() + self.root_dir = root_dir + self.train_dir = train_dir + self.img_size = tuple(img_size) + self.max_label_length = max_label_length + self.charset_train = charset_train + self.charset_test = charset_test + self.batch_size = batch_size + self.num_workers = num_workers + self.augment = augment + self.remove_whitespace = remove_whitespace + self.normalize_unicode = normalize_unicode + self.min_image_dim = min_image_dim + self.rotation = rotation + self.collate_fn = collate_fn + self._train_dataset = None + self._val_dataset = None + + @staticmethod + def get_transform(img_size: Tuple[int], augment: bool = False, rotation: int = 0): + transforms = [] + if augment: + from .augment import rand_augment_transform + transforms.append(rand_augment_transform()) + if rotation: + transforms.append(lambda img: img.rotate(rotation, expand=True)) + transforms.extend([ + T.Resize(img_size, T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(0.5, 0.5) + ]) + return T.Compose(transforms) + + @property + def train_dataset(self): + if self._train_dataset is None: + transform = self.get_transform(self.img_size, self.augment) + root = PurePath(self.root_dir, 'train', self.train_dir) + self._train_dataset = build_tree_dataset(root, self.charset_train, self.max_label_length, + self.min_image_dim, self.remove_whitespace, self.normalize_unicode, + transform=transform) + return self._train_dataset + + @property + def val_dataset(self): + if self._val_dataset is None: + transform = self.get_transform(self.img_size) + root = PurePath(self.root_dir, 'val') + self._val_dataset = build_tree_dataset(root, self.charset_test, self.max_label_length, + self.min_image_dim, self.remove_whitespace, self.normalize_unicode, + transform=transform) + return self._val_dataset + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, + num_workers=self.num_workers, persistent_workers=self.num_workers > 0, + pin_memory=True, collate_fn=self.collate_fn) + + def val_dataloader(self): + return DataLoader(self.val_dataset, batch_size=self.batch_size, + num_workers=self.num_workers, persistent_workers=self.num_workers > 0, + pin_memory=True, collate_fn=self.collate_fn) + + def test_dataloaders(self, subset): + transform = self.get_transform(self.img_size, rotation=self.rotation) + root = PurePath(self.root_dir, 'test') + datasets = {s: LmdbDataset(str(root / s), self.charset_test, self.max_label_length, + self.min_image_dim, self.remove_whitespace, self.normalize_unicode, + transform=transform) for s in subset} + return {k: DataLoader(v, 
batch_size=self.batch_size, num_workers=self.num_workers, + pin_memory=True, collate_fn=self.collate_fn) + for k, v in datasets.items()} diff --git a/strhub/data/utils.py b/strhub/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..52a5b56fb5ccacf07285b786a33a3e8102a99028 --- /dev/null +++ b/strhub/data/utils.py @@ -0,0 +1,148 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from abc import ABC, abstractmethod +from itertools import groupby +from typing import List, Optional, Tuple + +import torch +from torch import Tensor +from torch.nn.utils.rnn import pad_sequence + + +class CharsetAdapter: + """Transforms labels according to the target charset.""" + + def __init__(self, target_charset) -> None: + super().__init__() + self.lowercase_only = target_charset == target_charset.lower() + self.uppercase_only = target_charset == target_charset.upper() + self.unsupported = f'[^{re.escape(target_charset)}]' + + def __call__(self, label): + if self.lowercase_only: + label = label.lower() + elif self.uppercase_only: + label = label.upper() + # Remove unsupported characters + label = re.sub(self.unsupported, '', label) + return label + + +class BaseTokenizer(ABC): + + def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None: + self._itos = specials_first + tuple(charset) + specials_last + self._stoi = {s: i for i, s in enumerate(self._itos)} + + def __len__(self): + return len(self._itos) + + def _tok2ids(self, tokens: str) -> List[int]: + return [self._stoi[s] for s in tokens] + + def _ids2tok(self, token_ids: List[int], join: bool = True) -> str: + tokens = [self._itos[i] for i in token_ids] + return ''.join(tokens) if join else tokens + + @abstractmethod + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + """Encode a batch of labels to a representation suitable for the model. + + Args: + labels: List of labels. Each can be of arbitrary length. + device: Create tensor on this device. + + Returns: + Batched tensor representation padded to the max label length. Shape: N, L + """ + raise NotImplementedError + + @abstractmethod + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + """Internal method which performs the necessary filtering prior to decoding.""" + raise NotImplementedError + + def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]: + """Decode a batch of token distributions. + + Args: + token_dists: softmax probabilities over the token distribution. 
Shape: N, L, C + raw: return unprocessed labels (will return list of list of strings) + + Returns: + list of string labels (arbitrary length) and + their corresponding sequence probabilities as a list of Tensors + """ + batch_tokens = [] + batch_probs = [] + for dist in token_dists: + probs, ids = dist.max(-1) # greedy selection + if not raw: + probs, ids = self._filter(probs, ids) + tokens = self._ids2tok(ids, not raw) + batch_tokens.append(tokens) + batch_probs.append(probs) + return batch_tokens, batch_probs + + +class Tokenizer(BaseTokenizer): + BOS = '[B]' + EOS = '[E]' + PAD = '[P]' + + def __init__(self, charset: str) -> None: + specials_first = (self.EOS,) + specials_last = (self.BOS, self.PAD) + super().__init__(charset, specials_first, specials_last) + self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last] + + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device) + for y in labels] + return pad_sequence(batch, batch_first=True, padding_value=self.pad_id) + + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + ids = ids.tolist() + try: + eos_idx = ids.index(self.eos_id) + except ValueError: + eos_idx = len(ids) # Nothing to truncate. + # Truncate after EOS + ids = ids[:eos_idx] + probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists) + return probs, ids + + +class CTCTokenizer(BaseTokenizer): + BLANK = '[B]' + + def __init__(self, charset: str) -> None: + # BLANK uses index == 0 by default + super().__init__(charset, specials_first=(self.BLANK,)) + self.blank_id = self._stoi[self.BLANK] + + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + # We use a padded representation since we don't want to use CUDNN's CTC implementation + batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels] + return pad_sequence(batch, batch_first=True, padding_value=self.blank_id) + + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + # Best path decoding: + ids = list(zip(*groupby(ids.tolist())))[0] # Remove duplicate tokens + ids = [x for x in ids if x != self.blank_id] # Remove BLANKs + # `probs` is just pass-through since all positions are considered part of the path + return probs, ids diff --git a/strhub/models/.ipynb_checkpoints/base-checkpoint.py b/strhub/models/.ipynb_checkpoints/base-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..a9fafe450f615f027503522ff5c81443fe29aa4a --- /dev/null +++ b/strhub/models/.ipynb_checkpoints/base-checkpoint.py @@ -0,0 +1,202 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Tuple, List + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from nltk import edit_distance +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from timm.optim import create_optimizer_v2 +from torch import Tensor +from torch.optim import Optimizer +from torch.optim.lr_scheduler import OneCycleLR + +from strhub.data.utils import CharsetAdapter, CTCTokenizer, Tokenizer, BaseTokenizer + + +@dataclass +class BatchResult: + num_samples: int + correct: int + ned: float + confidence: float + label_length: int + loss: Tensor + loss_numel: int + + +class BaseSystem(pl.LightningModule, ABC): + + def __init__(self, tokenizer: BaseTokenizer, charset_test: str, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None: + super().__init__() + self.tokenizer = tokenizer + self.charset_adapter = CharsetAdapter(charset_test) + self.batch_size = batch_size + self.lr = lr + self.warmup_pct = warmup_pct + self.weight_decay = weight_decay + + @abstractmethod + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + """Inference + + Args: + images: Batch of images. Shape: N, Ch, H, W + max_length: Max sequence length of the output. If None, will use default. + + Returns: + logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials) + """ + raise NotImplementedError + + @abstractmethod + def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]: + """Like forward(), but also computes the loss (calls forward() internally). + + Args: + images: Batch of images. Shape: N, Ch, H, W + labels: Text labels of the images + + Returns: + logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials) + loss: mean loss for the batch + loss_numel: number of elements the loss was calculated from + """ + raise NotImplementedError + + def configure_optimizers(self): + agb = self.trainer.accumulate_grad_batches + # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP. + lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256. + lr = lr_scale * self.lr + optim = create_optimizer_v2(self, 'adamw', lr, self.weight_decay) + sched = OneCycleLR(optim, lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, + cycle_momentum=False) + return {'optimizer': optim, 'lr_scheduler': {'scheduler': sched, 'interval': 'step'}} + + def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): + optimizer.zero_grad(set_to_none=True) + + def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]: + images, labels = batch + + correct = 0 + total = 0 + ned = 0 + confidence = 0 + label_length = 0 + if validation: + logits, loss, loss_numel = self.forward_logits_loss(images, labels) + else: + # At test-time, we shouldn't specify a max_label_length because the test-time charset used + # might be different from the train-time charset. max_label_length in eval_logits_loss() is computed + # based on the transformed label, which could be wrong if the actual gt label contains characters existing + # in the train-time charset but not in the test-time charset. 
For example, "aishahaleyes.blogspot.com" + # is exactly 25 characters, but if processed by CharsetAdapter for the 36-char set, it becomes 23 characters + # long only, which sets max_label_length = 23. This will cause the model prediction to be truncated. + logits = self.forward(images) + loss = loss_numel = None # Only used for validation; not needed at test-time. + + probs = logits.softmax(-1) + preds, probs = self.tokenizer.decode(probs) + for pred, prob, gt in zip(preds, probs, labels): + confidence += prob.prod().item() + pred = self.charset_adapter(pred) + # Follow ICDAR 2019 definition of N.E.D. + ned += edit_distance(pred, gt) / max(len(pred), len(gt)) + if pred == gt: + correct += 1 + total += 1 + label_length += len(pred) + return dict(output=BatchResult(total, correct, ned, confidence, label_length, loss, loss_numel)) + + @staticmethod + def _aggregate_results(outputs: EPOCH_OUTPUT) -> Tuple[float, float, float]: + if not outputs: + return 0., 0., 0. + total_loss = 0 + total_loss_numel = 0 + total_n_correct = 0 + total_norm_ED = 0 + total_size = 0 + for result in outputs: + result = result['output'] + total_loss += result.loss_numel * result.loss + total_loss_numel += result.loss_numel + total_n_correct += result.correct + total_norm_ED += result.ned + total_size += result.num_samples + acc = total_n_correct / total_size + ned = (1 - total_norm_ED / total_size) + loss = total_loss / total_loss_numel + return acc, ned, loss + + def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]: + return self._eval_step(batch, True) + + def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + acc, ned, loss = self._aggregate_results(outputs) + self.log('val_accuracy', 100 * acc, sync_dist=True) + self.log('val_NED', 100 * ned, sync_dist=True) + self.log('val_loss', loss, sync_dist=True) + self.log('hp_metric', acc, sync_dist=True) + + def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]: + return self._eval_step(batch, False) + + +class CrossEntropySystem(BaseSystem): + + def __init__(self, charset_train: str, charset_test: str, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None: + tokenizer = Tokenizer(charset_train) + super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.bos_id = tokenizer.bos_id + self.eos_id = tokenizer.eos_id + self.pad_id = tokenizer.pad_id + + def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]: + targets = self.tokenizer.encode(labels, self.device) + targets = targets[:, 1:] # Discard + max_len = targets.shape[1] - 1 # exclude from count + logits = self.forward(images, max_len) + loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id) + loss_numel = (targets != self.pad_id).sum() + return logits, loss, loss_numel + + +class CTCSystem(BaseSystem): + + def __init__(self, charset_train: str, charset_test: str, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None: + tokenizer = CTCTokenizer(charset_train) + super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.blank_id = tokenizer.blank_id + + def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]: + targets = self.tokenizer.encode(labels, self.device) + logits = self.forward(images) + log_probs = logits.log_softmax(-1).transpose(0, 1) # swap batch and seq. 
dims + T, N, _ = log_probs.shape + input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long, device=self.device) + target_lengths = torch.as_tensor(list(map(len, labels)), dtype=torch.long, device=self.device) + loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True) + return logits, loss, N diff --git a/strhub/models/.ipynb_checkpoints/modules-checkpoint.py b/strhub/models/.ipynb_checkpoints/modules-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..a89d05f6afd67437f3cfa8aff6d2d8b12df3fafa --- /dev/null +++ b/strhub/models/.ipynb_checkpoints/modules-checkpoint.py @@ -0,0 +1,20 @@ +r"""Shared modules used by CRNN and TRBA""" +from torch import nn + + +class BidirectionalLSTM(nn.Module): + """Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py""" + + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) + self.linear = nn.Linear(hidden_size * 2, output_size) + + def forward(self, input): + """ + input : visual feature [batch_size x T x input_size], T = num_steps. + output : contextual feature [batch_size x T x output_size] + """ + recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) + output = self.linear(recurrent) # batch_size x T x output_size + return output diff --git a/strhub/models/.ipynb_checkpoints/utils-checkpoint.py b/strhub/models/.ipynb_checkpoints/utils-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..f53ddca31f10366840d00b7f82a44471dce089df --- /dev/null +++ b/strhub/models/.ipynb_checkpoints/utils-checkpoint.py @@ -0,0 +1,123 @@ +from pathlib import PurePath +from typing import Sequence + +import torch +from torch import nn + +import yaml + + +class InvalidModelError(RuntimeError): + """Exception raised for any model-related error (creation, loading)""" + + +_WEIGHTS_URL = { + 'parseq-tiny': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq_tiny-e7a21b54.pt', + 'parseq': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq-bb5792a6.pt', + 'abinet': 'https://github.com/baudm/parseq/releases/download/v1.0.0/abinet-1d1e373e.pt', + 'trba': 'https://github.com/baudm/parseq/releases/download/v1.0.0/trba-cfaed284.pt', + 'vitstr': 'https://github.com/baudm/parseq/releases/download/v1.0.0/vitstr-26d0fcf4.pt', + 'crnn': 'https://github.com/baudm/parseq/releases/download/v1.0.0/crnn-679d0e31.pt', +} + + +def _get_config(experiment: str, **kwargs): + """Emulates hydra config resolution""" + root = PurePath(__file__).parents[2] + with open(root / 'configs/main.yaml', 'r') as f: + config = yaml.load(f, yaml.Loader)['model'] + with open(root / f'configs/charset/94_full.yaml', 'r') as f: + config.update(yaml.load(f, yaml.Loader)['model']) + with open(root / f'configs/experiment/{experiment}.yaml', 'r') as f: + exp = yaml.load(f, yaml.Loader) + # Apply base model config + model = exp['defaults'][0]['override /model'] + with open(root / f'configs/model/{model}.yaml', 'r') as f: + config.update(yaml.load(f, yaml.Loader)) + # Apply experiment config + if 'model' in exp: + config.update(exp['model']) + config.update(kwargs) + # Workaround for now: manually cast the lr to the correct type. 
+ config['lr'] = float(config['lr']) + return config + + +def _get_model_class(key): + if 'abinet' in key: + from .abinet.system import ABINet as ModelClass + elif 'crnn' in key: + from .crnn.system import CRNN as ModelClass + elif 'parseq' in key: + from .parseq.system import PARSeq as ModelClass + elif 'trba' in key: + from .trba.system import TRBA as ModelClass + elif 'trbc' in key: + from .trba.system import TRBC as ModelClass + elif 'vitstr' in key: + from .vitstr.system import ViTSTR as ModelClass + else: + raise InvalidModelError("Unable to find model class for '{}'".format(key)) + return ModelClass + + +def get_pretrained_weights(experiment): + try: + url = _WEIGHTS_URL[experiment] + except KeyError: + raise InvalidModelError("No pretrained weights found for '{}'".format(experiment)) from None + return torch.hub.load_state_dict_from_url(url=url, map_location='cpu', check_hash=True) + + +def create_model(experiment: str, pretrained: bool = False, **kwargs): + try: + config = _get_config(experiment, **kwargs) + except FileNotFoundError: + raise InvalidModelError("No configuration found for '{}'".format(experiment)) from None + ModelClass = _get_model_class(experiment) + model = ModelClass(**config) + if pretrained: + model.load_state_dict(get_pretrained_weights(experiment)) + return model + + +def load_from_checkpoint(checkpoint_path: str, **kwargs): + if checkpoint_path.startswith('pretrained='): + model_id = checkpoint_path.split('=', maxsplit=1)[1] + model = create_model(model_id, True, **kwargs) + else: + ModelClass = _get_model_class(checkpoint_path) + model = ModelClass.load_from_checkpoint(checkpoint_path, **kwargs) + return model + + +def parse_model_args(args): + kwargs = {} + arg_types = {t.__name__: t for t in [int, float, str]} + arg_types['bool'] = lambda v: v.lower() == 'true' # special handling for bool + for arg in args: + name, value = arg.split('=', maxsplit=1) + name, arg_type = name.split(':', maxsplit=1) + kwargs[name] = arg_types[arg_type](value) + return kwargs + + +def init_weights(module: nn.Module, name: str = '', exclude: Sequence[str] = ()): + """Initialize the weights using the typical initialization schemes used in SOTA models.""" + if any(map(name.startswith, exclude)): + return + if isinstance(module, nn.Linear): + nn.init.trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.trunc_normal_(module.weight, std=.02) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) diff --git a/strhub/models/__init__.py b/strhub/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/strhub/models/__pycache__/__init__.cpython-37.pyc b/strhub/models/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4250ce963600ef76bf74eedcd1b36b514aec2e5a Binary files /dev/null and b/strhub/models/__pycache__/__init__.cpython-37.pyc differ diff --git a/strhub/models/__pycache__/base.cpython-37.pyc b/strhub/models/__pycache__/base.cpython-37.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..80ae9acbb7f79de2aea1903fe41e878e2fa977e0 Binary files /dev/null and b/strhub/models/__pycache__/base.cpython-37.pyc differ diff --git a/strhub/models/__pycache__/utils.cpython-37.pyc b/strhub/models/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b684c8be96089b532a121e4147d7e0d5e9a20a3f Binary files /dev/null and b/strhub/models/__pycache__/utils.cpython-37.pyc differ diff --git a/strhub/models/abinet/LICENSE b/strhub/models/abinet/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2f1d4adb4889b2719f13ed6edf56aed10246a516 --- /dev/null +++ b/strhub/models/abinet/LICENSE @@ -0,0 +1,25 @@ +ABINet for non-commercial purposes + +Copyright (c) 2021, USTC +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/strhub/models/abinet/__init__.py b/strhub/models/abinet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..604811036fda52d8485eecfebd4ffeb7f7176042 --- /dev/null +++ b/strhub/models/abinet/__init__.py @@ -0,0 +1,13 @@ +r""" +Fang, Shancheng, Hongtao, Xie, Yuxin, Wang, Zhendong, Mao, and Yongdong, Zhang. +"Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition." . +In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 7098-7107).2021. + +https://arxiv.org/abs/2103.06495 + +All source files, except `system.py`, are based on the implementation listed below, +and hence are released under the license of the original. 
+ +Source: https://github.com/FangShancheng/ABINet +License: 2-clause BSD License (see included LICENSE file) +""" diff --git a/strhub/models/abinet/attention.py b/strhub/models/abinet/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..cc8fba0638e7444fdffe964f72d0566c1a5bb818 --- /dev/null +++ b/strhub/models/abinet/attention.py @@ -0,0 +1,100 @@ +import torch +import torch.nn as nn + +from .transformer import PositionalEncoding + + +class Attention(nn.Module): + def __init__(self, in_channels=512, max_length=25, n_feature=256): + super().__init__() + self.max_length = max_length + + self.f0_embedding = nn.Embedding(max_length, in_channels) + self.w0 = nn.Linear(max_length, n_feature) + self.wv = nn.Linear(in_channels, in_channels) + self.we = nn.Linear(in_channels, max_length) + + self.active = nn.Tanh() + self.softmax = nn.Softmax(dim=2) + + def forward(self, enc_output): + enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2) + reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device) + reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S) + reading_order_embed = self.f0_embedding(reading_order) # b,25,512 + + t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256 + t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512 + + attn = self.we(t) # b,256,25 + attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256 + g_output = torch.bmm(attn, enc_output) # b,25,512 + return g_output, attn.view(*attn.shape[:2], 8, 32) + + +def encoder_layer(in_c, out_c, k=3, s=2, p=1): + return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p), + nn.BatchNorm2d(out_c), + nn.ReLU(True)) + + +def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None): + align_corners = None if mode == 'nearest' else True + return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor, + mode=mode, align_corners=align_corners), + nn.Conv2d(in_c, out_c, k, s, p), + nn.BatchNorm2d(out_c), + nn.ReLU(True)) + + +class PositionAttention(nn.Module): + def __init__(self, max_length, in_channels=512, num_channels=64, + h=8, w=32, mode='nearest', **kwargs): + super().__init__() + self.max_length = max_length + self.k_encoder = nn.Sequential( + encoder_layer(in_channels, num_channels, s=(1, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)) + ) + self.k_decoder = nn.Sequential( + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, in_channels, size=(h, w), mode=mode) + ) + + self.pos_encoder = PositionalEncoding(in_channels, dropout=0., max_len=max_length) + self.project = nn.Linear(in_channels, in_channels) + + def forward(self, x): + N, E, H, W = x.size() + k, v = x, x # (N, E, H, W) + + # calculate key vector + features = [] + for i in range(0, len(self.k_encoder)): + k = self.k_encoder[i](k) + features.append(k) + for i in range(0, len(self.k_decoder) - 1): + k = self.k_decoder[i](k) + k = k + features[len(self.k_decoder) - 2 - i] + k = self.k_decoder[-1](k) + + # calculate query vector + # TODO q=f(q,k) + zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E) + q = self.pos_encoder(zeros) # (T, N, E) + q = q.permute(1, 0, 2) # (N, T, E) + q = self.project(q) # (N, T, E) 
+ + # calculate attention + attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W)) + attn_scores = attn_scores / (E ** 0.5) + attn_scores = torch.softmax(attn_scores, dim=-1) + + v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E) + attn_vecs = torch.bmm(attn_scores, v) # (N, T, E) + + return attn_vecs, attn_scores.view(N, -1, H, W) diff --git a/strhub/models/abinet/backbone.py b/strhub/models/abinet/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..debcabd7f115db0e698a55175a01a0ff0131e10f --- /dev/null +++ b/strhub/models/abinet/backbone.py @@ -0,0 +1,24 @@ +import torch.nn as nn +from torch.nn import TransformerEncoderLayer, TransformerEncoder + +from .resnet import resnet45 +from .transformer import PositionalEncoding + + +class ResTranformer(nn.Module): + def __init__(self, d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', backbone_ln=2): + super().__init__() + self.resnet = resnet45() + self.pos_encoder = PositionalEncoding(d_model, max_len=8 * 32) + encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead, + dim_feedforward=d_inner, dropout=dropout, activation=activation) + self.transformer = TransformerEncoder(encoder_layer, backbone_ln) + + def forward(self, images): + feature = self.resnet(images) + n, c, h, w = feature.shape + feature = feature.view(n, c, -1).permute(2, 0, 1) + feature = self.pos_encoder(feature) + feature = self.transformer(feature) + feature = feature.permute(1, 2, 0).view(n, c, h, w) + return feature diff --git a/strhub/models/abinet/model.py b/strhub/models/abinet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..cc0cd143d324822c57b897b6e5749024d857fd30 --- /dev/null +++ b/strhub/models/abinet/model.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn + + +class Model(nn.Module): + + def __init__(self, dataset_max_length: int, null_label: int): + super().__init__() + self.max_length = dataset_max_length + 1 # additional stop token + self.null_label = null_label + + def _get_length(self, logit, dim=-1): + """ Greed decoder to obtain length from logit""" + out = (logit.argmax(dim=-1) == self.null_label) + abn = out.any(dim) + out = ((out.cumsum(dim) == 1) & out).max(dim)[1] + out = out + 1 # additional end token + out = torch.where(abn, out, out.new_tensor(logit.shape[1], device=out.device)) + return out + + @staticmethod + def _get_padding_mask(length, max_length): + length = length.unsqueeze(-1) + grid = torch.arange(0, max_length, device=length.device).unsqueeze(0) + return grid >= length + + @staticmethod + def _get_location_mask(sz, device=None): + mask = torch.eye(sz, device=device) + mask = mask.float().masked_fill(mask == 1, float('-inf')) + return mask diff --git a/strhub/models/abinet/model_abinet_iter.py b/strhub/models/abinet/model_abinet_iter.py new file mode 100644 index 0000000000000000000000000000000000000000..1a8523ff6431f991037d56dc8dd72ae67c7bf242 --- /dev/null +++ b/strhub/models/abinet/model_abinet_iter.py @@ -0,0 +1,39 @@ +import torch +from torch import nn + +from .model_alignment import BaseAlignment +from .model_language import BCNLanguage +from .model_vision import BaseVision + + +class ABINetIterModel(nn.Module): + def __init__(self, dataset_max_length, null_label, num_classes, iter_size=1, + d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', + v_loss_weight=1., v_attention='position', v_attention_mode='nearest', + v_backbone='transformer', v_num_layers=2, + l_loss_weight=1., l_num_layers=4, l_detach=True, 
l_use_self_attn=False, + a_loss_weight=1.): + super().__init__() + self.iter_size = iter_size + self.vision = BaseVision(dataset_max_length, null_label, num_classes, v_attention, v_attention_mode, + v_loss_weight, d_model, nhead, d_inner, dropout, activation, v_backbone, v_num_layers) + self.language = BCNLanguage(dataset_max_length, null_label, num_classes, d_model, nhead, d_inner, dropout, + activation, l_num_layers, l_detach, l_use_self_attn, l_loss_weight) + self.alignment = BaseAlignment(dataset_max_length, null_label, num_classes, d_model, a_loss_weight) + + def forward(self, images): + v_res = self.vision(images) + a_res = v_res + all_l_res, all_a_res = [], [] + for _ in range(self.iter_size): + tokens = torch.softmax(a_res['logits'], dim=-1) + lengths = a_res['pt_lengths'] + lengths.clamp_(2, self.language.max_length) # TODO:move to langauge model + l_res = self.language(tokens, lengths) + all_l_res.append(l_res) + a_res = self.alignment(l_res['feature'], v_res['feature']) + all_a_res.append(a_res) + if self.training: + return all_a_res, all_l_res, v_res + else: + return a_res, all_l_res[-1], v_res diff --git a/strhub/models/abinet/model_alignment.py b/strhub/models/abinet/model_alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..9ccfa95e65dbd7176c8bcee693bb0bcb8ad13c69 --- /dev/null +++ b/strhub/models/abinet/model_alignment.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn + +from .model import Model + + +class BaseAlignment(Model): + def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, loss_weight=1.0): + super().__init__(dataset_max_length, null_label) + self.loss_weight = loss_weight + self.w_att = nn.Linear(2 * d_model, d_model) + self.cls = nn.Linear(d_model, num_classes) + + def forward(self, l_feature, v_feature): + """ + Args: + l_feature: (N, T, E) where T is length, N is batch size and d is dim of model + v_feature: (N, T, E) shape the same as l_feature + """ + f = torch.cat((l_feature, v_feature), dim=2) + f_att = torch.sigmoid(self.w_att(f)) + output = f_att * v_feature + (1 - f_att) * l_feature + + logits = self.cls(output) # (N, T, C) + pt_lengths = self._get_length(logits) + + return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight': self.loss_weight, + 'name': 'alignment'} diff --git a/strhub/models/abinet/model_language.py b/strhub/models/abinet/model_language.py new file mode 100644 index 0000000000000000000000000000000000000000..659d446578915bd1cab945554be749b4f1b0dff3 --- /dev/null +++ b/strhub/models/abinet/model_language.py @@ -0,0 +1,50 @@ +import torch.nn as nn +from torch.nn import TransformerDecoder + +from .model import Model +from .transformer import PositionalEncoding, TransformerDecoderLayer + + +class BCNLanguage(Model): + def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, nhead=8, d_inner=2048, dropout=0.1, + activation='relu', num_layers=4, detach=True, use_self_attn=False, loss_weight=1.0, + global_debug=False): + super().__init__(dataset_max_length, null_label) + self.detach = detach + self.loss_weight = loss_weight + self.proj = nn.Linear(num_classes, d_model, False) + self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length) + self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length) + decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout, + activation, self_attn=use_self_attn, debug=global_debug) + self.model = TransformerDecoder(decoder_layer, num_layers) + self.cls = nn.Linear(d_model, 
num_classes) + + def forward(self, tokens, lengths): + """ + Args: + tokens: (N, T, C) where T is length, N is batch size and C is classes number + lengths: (N,) + """ + if self.detach: + tokens = tokens.detach() + embed = self.proj(tokens) # (N, T, E) + embed = embed.permute(1, 0, 2) # (T, N, E) + embed = self.token_encoder(embed) # (T, N, E) + padding_mask = self._get_padding_mask(lengths, self.max_length) + + zeros = embed.new_zeros(*embed.shape) + qeury = self.pos_encoder(zeros) + location_mask = self._get_location_mask(self.max_length, tokens.device) + output = self.model(qeury, embed, + tgt_key_padding_mask=padding_mask, + memory_mask=location_mask, + memory_key_padding_mask=padding_mask) # (T, N, E) + output = output.permute(1, 0, 2) # (N, T, E) + + logits = self.cls(output) # (N, T, C) + pt_lengths = self._get_length(logits) + + res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths, + 'loss_weight': self.loss_weight, 'name': 'language'} + return res diff --git a/strhub/models/abinet/model_vision.py b/strhub/models/abinet/model_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..bddb7d5f237854b81c388090e2e20fc26632c431 --- /dev/null +++ b/strhub/models/abinet/model_vision.py @@ -0,0 +1,45 @@ +from torch import nn + +from .attention import PositionAttention, Attention +from .backbone import ResTranformer +from .model import Model +from .resnet import resnet45 + + +class BaseVision(Model): + def __init__(self, dataset_max_length, null_label, num_classes, + attention='position', attention_mode='nearest', loss_weight=1.0, + d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', + backbone='transformer', backbone_ln=2): + super().__init__(dataset_max_length, null_label) + self.loss_weight = loss_weight + self.out_channels = d_model + + if backbone == 'transformer': + self.backbone = ResTranformer(d_model, nhead, d_inner, dropout, activation, backbone_ln) + else: + self.backbone = resnet45() + + if attention == 'position': + self.attention = PositionAttention( + max_length=self.max_length, + mode=attention_mode + ) + elif attention == 'attention': + self.attention = Attention( + max_length=self.max_length, + n_feature=8 * 32, + ) + else: + raise ValueError(f'invalid attention: {attention}') + + self.cls = nn.Linear(self.out_channels, num_classes) + + def forward(self, images): + features = self.backbone(images) # (N, E, H, W) + attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W) + logits = self.cls(attn_vecs) # (N, T, C) + pt_lengths = self._get_length(logits) + + return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths, + 'attn_scores': attn_scores, 'loss_weight': self.loss_weight, 'name': 'vision'} diff --git a/strhub/models/abinet/resnet.py b/strhub/models/abinet/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..59bf38896987b3560e254e8037426d29bcdd5844 --- /dev/null +++ b/strhub/models/abinet/resnet.py @@ -0,0 +1,72 @@ +import math +from typing import Optional, Callable + +import torch.nn as nn +from torchvision.models import resnet + + +class BasicBlock(resnet.BasicBlock): + + def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, + groups: int = 1, base_width: int = 64, dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None) -> None: + super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer) + self.conv1 = resnet.conv1x1(inplanes, planes) + self.conv2 = 
resnet.conv3x3(planes, planes, stride) + + +class ResNet(nn.Module): + + def __init__(self, block, layers): + super().__init__() + self.inplanes = 32 + self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, 32, layers[0], stride=2) + self.layer2 = self._make_layer(block, 64, layers[1], stride=1) + self.layer3 = self._make_layer(block, 128, layers[2], stride=2) + self.layer4 = self._make_layer(block, 256, layers[3], stride=1) + self.layer5 = self._make_layer(block, 512, layers[4], stride=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.layer5(x) + return x + + +def resnet45(): + return ResNet(BasicBlock, [3, 4, 6, 6, 3]) diff --git a/strhub/models/abinet/system.py b/strhub/models/abinet/system.py new file mode 100644 index 0000000000000000000000000000000000000000..fc8f2cd8d274f45007f97301a555d884398a2e97 --- /dev/null +++ b/strhub/models/abinet/system.py @@ -0,0 +1,172 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
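# Illustrative shape check for the ABINet backbone above -- a sketch, assuming the
# 32x128 input size implied by PositionalEncoding(max_len=8 * 32) in backbone.py:
import torch
from strhub.models.abinet.backbone import ResTranformer

x = torch.randn(2, 3, 32, 128)         # (N, C, H, W)
feats = ResTranformer(d_model=512)(x)  # ResNet-45 has two stride-2 stages -> 4x downsampling
print(feats.shape)                     # torch.Size([2, 512, 8, 32])
# The H*W = 256 positions of this map are what PositionAttention/Attention flatten and
# attend over to produce the (N, T, E) character features.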
+ +import logging +import math +from typing import Any, Tuple, List, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.optim import AdamW +from torch.optim.lr_scheduler import OneCycleLR + +from pytorch_lightning.utilities.types import STEP_OUTPUT +from timm.optim.optim_factory import param_groups_weight_decay + +from strhub.models.base import CrossEntropySystem +from strhub.models.utils import init_weights +from .model_abinet_iter import ABINetIterModel as Model + +log = logging.getLogger(__name__) + + +class ABINet(CrossEntropySystem): + + def __init__(self, charset_train: str, charset_test: str, max_label_length: int, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float, + iter_size: int, d_model: int, nhead: int, d_inner: int, dropout: float, activation: str, + v_loss_weight: float, v_attention: str, v_attention_mode: str, v_backbone: str, v_num_layers: int, + l_loss_weight: float, l_num_layers: int, l_detach: bool, l_use_self_attn: bool, + l_lr: float, a_loss_weight: float, lm_only: bool = False, **kwargs) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.scheduler = None + self.save_hyperparameters() + self.max_label_length = max_label_length + self.num_classes = len(self.tokenizer) - 2 # We don't predict nor + self.model = Model(max_label_length, self.eos_id, self.num_classes, iter_size, d_model, nhead, d_inner, + dropout, activation, v_loss_weight, v_attention, v_attention_mode, v_backbone, v_num_layers, + l_loss_weight, l_num_layers, l_detach, l_use_self_attn, a_loss_weight) + self.model.apply(init_weights) + # FIXME: doesn't support resumption from checkpoint yet + self._reset_alignment = True + self._reset_optimizers = True + self.l_lr = l_lr + self.lm_only = lm_only + # Train LM only. Freeze other submodels. + if lm_only: + self.l_lr = lr # for tuning + self.model.vision.requires_grad_(False) + self.model.alignment.requires_grad_(False) + + @property + def _pretraining(self): + # In the original work, VM was pretrained for 8 epochs while full model was trained for an additional 10 epochs. + total_steps = self.trainer.estimated_stepping_batches * self.trainer.accumulate_grad_batches + return self.global_step < (8 / (8 + 10)) * total_steps + + @torch.jit.ignore + def no_weight_decay(self): + return {'model.language.proj.weight'} + + def _add_weight_decay(self, model: nn.Module, skip_list=()): + if self.weight_decay: + return param_groups_weight_decay(model, self.weight_decay, skip_list) + else: + return [{'params': model.parameters()}] + + def configure_optimizers(self): + agb = self.trainer.accumulate_grad_batches + # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP. + lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256. + lr = lr_scale * self.lr + l_lr = lr_scale * self.l_lr + params = [] + params.extend(self._add_weight_decay(self.model.vision)) + params.extend(self._add_weight_decay(self.model.alignment)) + # We use a different learning rate for the LM. 
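# (When lm_only=True, __init__ has already set l_lr = lr, so in that case both parameter
# groups end up with the same base rate.)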
+ for p in self._add_weight_decay(self.model.language, ('proj.weight',)): + p['lr'] = l_lr + params.append(p) + max_lr = [p.get('lr', lr) for p in params] + optim = AdamW(params, lr) + self.scheduler = OneCycleLR(optim, max_lr, self.trainer.estimated_stepping_batches, + pct_start=self.warmup_pct, cycle_momentum=False) + return {'optimizer': optim, 'lr_scheduler': {'scheduler': self.scheduler, 'interval': 'step'}} + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) + logits = self.model.forward(images)[0]['logits'] + return logits[:, :max_length + 1] # truncate + + def calc_loss(self, targets, *res_lists) -> Tensor: + total_loss = 0 + for res_list in res_lists: + loss = 0 + if isinstance(res_list, dict): + res_list = [res_list] + for res in res_list: + logits = res['logits'].flatten(end_dim=1) + loss += F.cross_entropy(logits, targets.flatten(), ignore_index=self.pad_id) + loss /= len(res_list) + self.log('loss_' + res_list[0]['name'], loss) + total_loss += res_list[0]['loss_weight'] * loss + return total_loss + + def on_train_batch_start(self, batch: Any, batch_idx: int) -> None: + if not self._pretraining and self._reset_optimizers: + log.info('Pretraining ends. Updating base LRs.') + self._reset_optimizers = False + # Make base_lr the same for all groups + base_lr = self.scheduler.base_lrs[0] # base_lr of group 0 - VM + self.scheduler.base_lrs = [base_lr] * len(self.scheduler.base_lrs) + + def _prepare_inputs_and_targets(self, labels): + # Use dummy label to ensure sequence length is constant. + dummy = ['0' * self.max_label_length] + targets = self.tokenizer.encode(dummy + list(labels), self.device)[1:] + targets = targets[:, 1:] # remove . Unused here. + # Inputs are padded with eos_id + inputs = torch.where(targets == self.pad_id, self.eos_id, targets) + inputs = F.one_hot(inputs, self.num_classes).float() + lengths = torch.as_tensor(list(map(len, labels)), device=self.device) + 1 # +1 for eos + return inputs, lengths, targets + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + inputs, lengths, targets = self._prepare_inputs_and_targets(labels) + if self.lm_only: + l_res = self.model.language(inputs, lengths) + loss = self.calc_loss(targets, l_res) + # Pretrain submodels independently first + elif self._pretraining: + # Vision + v_res = self.model.vision(images) + # Language + l_res = self.model.language(inputs, lengths) + # We also train the alignment model to 'satisfy' DDP requirements (all parameters should be used). + # We'll reset its parameters prior to joint training. + a_res = self.model.alignment(l_res['feature'].detach(), v_res['feature'].detach()) + loss = self.calc_loss(targets, v_res, l_res, a_res) + else: + # Reset alignment model's parameters once prior to full model training. + if self._reset_alignment: + log.info('Pretraining ends. 
Resetting alignment model.') + self._reset_alignment = False + self.model.alignment.apply(init_weights) + all_a_res, all_l_res, v_res = self.model.forward(images) + loss = self.calc_loss(targets, v_res, all_l_res, all_a_res) + self.log('loss', loss) + return loss + + def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]: + if self.lm_only: + inputs, lengths, targets = self._prepare_inputs_and_targets(labels) + l_res = self.model.language(inputs, lengths) + loss = self.calc_loss(targets, l_res) + loss_numel = (targets != self.pad_id).sum() + return l_res['logits'], loss, loss_numel + else: + return super().forward_logits_loss(images, labels) diff --git a/strhub/models/abinet/transformer.py b/strhub/models/abinet/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a920805d67eea5671675f5623a47d06ad8af894b --- /dev/null +++ b/strhub/models/abinet/transformer.py @@ -0,0 +1,143 @@ +import math + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.modules.transformer import _get_activation_fn + + +class TransformerDecoderLayer(nn.Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). 
+ + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", self_attn=True, siamese=False, debug=False): + super().__init__() + self.has_self_attn, self.siamese = self_attn, siamese + self.debug = debug + if self.has_self_attn: + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.norm1 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + if self.siamese: + self.multihead_attn2 = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super().__setstate__(state) + + def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, + tgt_key_padding_mask=None, memory_key_padding_mask=None, + memory2=None, memory_mask2=None, memory_key_padding_mask2=None): + # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor + r"""Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + if self.has_self_attn: + tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + if self.debug: self.attn = attn + tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + if self.debug: self.attn2 = attn2 + + if self.siamese: + tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2, + key_padding_mask=memory_key_padding_mask2) + tgt = tgt + self.dropout2(tgt3) + if self.debug: self.attn3 = attn3 + + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + + return tgt + + +class PositionalEncoding(nn.Module): + r"""Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. 
length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, d_model, dropout=0.1, max_len=5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer('pe', pe) + + def forward(self, x): + r"""Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + + x = x + self.pe[:x.size(0), :] + return self.dropout(x) diff --git a/strhub/models/base.py b/strhub/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a9fafe450f615f027503522ff5c81443fe29aa4a --- /dev/null +++ b/strhub/models/base.py @@ -0,0 +1,202 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Tuple, List + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from nltk import edit_distance +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from timm.optim import create_optimizer_v2 +from torch import Tensor +from torch.optim import Optimizer +from torch.optim.lr_scheduler import OneCycleLR + +from strhub.data.utils import CharsetAdapter, CTCTokenizer, Tokenizer, BaseTokenizer + + +@dataclass +class BatchResult: + num_samples: int + correct: int + ned: float + confidence: float + label_length: int + loss: Tensor + loss_numel: int + + +class BaseSystem(pl.LightningModule, ABC): + + def __init__(self, tokenizer: BaseTokenizer, charset_test: str, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None: + super().__init__() + self.tokenizer = tokenizer + self.charset_adapter = CharsetAdapter(charset_test) + self.batch_size = batch_size + self.lr = lr + self.warmup_pct = warmup_pct + self.weight_decay = weight_decay + + @abstractmethod + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + """Inference + + Args: + images: Batch of images. Shape: N, Ch, H, W + max_length: Max sequence length of the output. If None, will use default. + + Returns: + logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials) + """ + raise NotImplementedError + + @abstractmethod + def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]: + """Like forward(), but also computes the loss (calls forward() internally). 
+ + Args: + images: Batch of images. Shape: N, Ch, H, W + labels: Text labels of the images + + Returns: + logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials) + loss: mean loss for the batch + loss_numel: number of elements the loss was calculated from + """ + raise NotImplementedError + + def configure_optimizers(self): + agb = self.trainer.accumulate_grad_batches + # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP. + lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256. + lr = lr_scale * self.lr + optim = create_optimizer_v2(self, 'adamw', lr, self.weight_decay) + sched = OneCycleLR(optim, lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, + cycle_momentum=False) + return {'optimizer': optim, 'lr_scheduler': {'scheduler': sched, 'interval': 'step'}} + + def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): + optimizer.zero_grad(set_to_none=True) + + def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]: + images, labels = batch + + correct = 0 + total = 0 + ned = 0 + confidence = 0 + label_length = 0 + if validation: + logits, loss, loss_numel = self.forward_logits_loss(images, labels) + else: + # At test-time, we shouldn't specify a max_label_length because the test-time charset used + # might be different from the train-time charset. max_label_length in eval_logits_loss() is computed + # based on the transformed label, which could be wrong if the actual gt label contains characters existing + # in the train-time charset but not in the test-time charset. For example, "aishahaleyes.blogspot.com" + # is exactly 25 characters, but if processed by CharsetAdapter for the 36-char set, it becomes 23 characters + # long only, which sets max_label_length = 23. This will cause the model prediction to be truncated. + logits = self.forward(images) + loss = loss_numel = None # Only used for validation; not needed at test-time. + + probs = logits.softmax(-1) + preds, probs = self.tokenizer.decode(probs) + for pred, prob, gt in zip(preds, probs, labels): + confidence += prob.prod().item() + pred = self.charset_adapter(pred) + # Follow ICDAR 2019 definition of N.E.D. + ned += edit_distance(pred, gt) / max(len(pred), len(gt)) + if pred == gt: + correct += 1 + total += 1 + label_length += len(pred) + return dict(output=BatchResult(total, correct, ned, confidence, label_length, loss, loss_numel)) + + @staticmethod + def _aggregate_results(outputs: EPOCH_OUTPUT) -> Tuple[float, float, float]: + if not outputs: + return 0., 0., 0. 
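# Each batch's loss below is weighted by its loss_numel, so the aggregated loss is a
# per-element mean rather than a mean of per-batch means.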
+ total_loss = 0 + total_loss_numel = 0 + total_n_correct = 0 + total_norm_ED = 0 + total_size = 0 + for result in outputs: + result = result['output'] + total_loss += result.loss_numel * result.loss + total_loss_numel += result.loss_numel + total_n_correct += result.correct + total_norm_ED += result.ned + total_size += result.num_samples + acc = total_n_correct / total_size + ned = (1 - total_norm_ED / total_size) + loss = total_loss / total_loss_numel + return acc, ned, loss + + def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]: + return self._eval_step(batch, True) + + def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + acc, ned, loss = self._aggregate_results(outputs) + self.log('val_accuracy', 100 * acc, sync_dist=True) + self.log('val_NED', 100 * ned, sync_dist=True) + self.log('val_loss', loss, sync_dist=True) + self.log('hp_metric', acc, sync_dist=True) + + def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]: + return self._eval_step(batch, False) + + +class CrossEntropySystem(BaseSystem): + + def __init__(self, charset_train: str, charset_test: str, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None: + tokenizer = Tokenizer(charset_train) + super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.bos_id = tokenizer.bos_id + self.eos_id = tokenizer.eos_id + self.pad_id = tokenizer.pad_id + + def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]: + targets = self.tokenizer.encode(labels, self.device) + targets = targets[:, 1:] # Discard + max_len = targets.shape[1] - 1 # exclude from count + logits = self.forward(images, max_len) + loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id) + loss_numel = (targets != self.pad_id).sum() + return logits, loss, loss_numel + + +class CTCSystem(BaseSystem): + + def __init__(self, charset_train: str, charset_test: str, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None: + tokenizer = CTCTokenizer(charset_train) + super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.blank_id = tokenizer.blank_id + + def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]: + targets = self.tokenizer.encode(labels, self.device) + logits = self.forward(images) + log_probs = logits.log_softmax(-1).transpose(0, 1) # swap batch and seq. 
dims + T, N, _ = log_probs.shape + input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long, device=self.device) + target_lengths = torch.as_tensor(list(map(len, labels)), dtype=torch.long, device=self.device) + loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True) + return logits, loss, N diff --git a/strhub/models/crnn/LICENSE b/strhub/models/crnn/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f98687be392fdce266708e79885aadaa4991b67f --- /dev/null +++ b/strhub/models/crnn/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017 Jieru Mei + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/strhub/models/crnn/__init__.py b/strhub/models/crnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4535947d9233c8fb0a85e9c22b151697d37f410 --- /dev/null +++ b/strhub/models/crnn/__init__.py @@ -0,0 +1,13 @@ +r""" +Shi, Baoguang, Xiang Bai, and Cong Yao. +"An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition." +IEEE transactions on pattern analysis and machine intelligence 39, no. 11 (2016): 2298-2304. + +https://arxiv.org/abs/1507.05717 + +All source files, except `system.py`, are based on the implementation listed below, +and hence are released under the license of the original. 
+ +Source: https://github.com/meijieru/crnn.pytorch +License: MIT License (see included LICENSE file) +""" diff --git a/strhub/models/crnn/model.py b/strhub/models/crnn/model.py new file mode 100644 index 0000000000000000000000000000000000000000..1a71845fba242c3c63c15a79cf43134a35807453 --- /dev/null +++ b/strhub/models/crnn/model.py @@ -0,0 +1,62 @@ +import torch.nn as nn + +from strhub.models.modules import BidirectionalLSTM + + +class CRNN(nn.Module): + + def __init__(self, img_h, nc, nclass, nh, leaky_relu=False): + super().__init__() + assert img_h % 16 == 0, 'img_h has to be a multiple of 16' + + ks = [3, 3, 3, 3, 3, 3, 2] + ps = [1, 1, 1, 1, 1, 1, 0] + ss = [1, 1, 1, 1, 1, 1, 1] + nm = [64, 128, 256, 256, 512, 512, 512] + + cnn = nn.Sequential() + + def convRelu(i, batchNormalization=False): + nIn = nc if i == 0 else nm[i - 1] + nOut = nm[i] + cnn.add_module('conv{0}'.format(i), + nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i], bias=not batchNormalization)) + if batchNormalization: + cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut)) + if leaky_relu: + cnn.add_module('relu{0}'.format(i), + nn.LeakyReLU(0.2, inplace=True)) + else: + cnn.add_module('relu{0}'.format(i), nn.ReLU(True)) + + convRelu(0) + cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64 + convRelu(1) + cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32 + convRelu(2, True) + convRelu(3) + cnn.add_module('pooling{0}'.format(2), + nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 + convRelu(4, True) + convRelu(5) + cnn.add_module('pooling{0}'.format(3), + nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16 + convRelu(6, True) # 512x1x16 + + self.cnn = cnn + self.rnn = nn.Sequential( + BidirectionalLSTM(512, nh, nh), + BidirectionalLSTM(nh, nh, nclass)) + + def forward(self, input): + # conv features + conv = self.cnn(input) + b, c, h, w = conv.size() + assert h == 1, 'the height of conv must be 1' + conv = conv.squeeze(2) + conv = conv.transpose(1, 2) # [b, w, c] + + # rnn features + output = self.rnn(conv) + + return output diff --git a/strhub/models/crnn/system.py b/strhub/models/crnn/system.py new file mode 100644 index 0000000000000000000000000000000000000000..abcb28e5c29f4ca484b87e283d1e4615e56378f2 --- /dev/null +++ b/strhub/models/crnn/system.py @@ -0,0 +1,43 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
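# Rough shape walk-through for the CRNN above -- a sketch assuming a 32x128 crop and an
# example 37-class charset (other widths/charsets change T and C):
import torch
from strhub.models.crnn.model import CRNN

net = CRNN(img_h=32, nc=3, nclass=37, nh=256)
y = net(torch.randn(2, 3, 32, 128))  # the CNN collapses height to 1; two BiLSTMs run over width
print(y.shape)                       # torch.Size([2, 33, 37]) -- (N, T, C) logits consumed by CTC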
+ +from typing import Sequence, Optional + +from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch import Tensor + +from strhub.models.base import CTCSystem +from strhub.models.utils import init_weights +from .model import CRNN as Model + + +class CRNN(CTCSystem): + + def __init__(self, charset_train: str, charset_test: str, max_label_length: int, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float, + img_size: Sequence[int], hidden_size: int, leaky_relu: bool, **kwargs) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + self.model = Model(img_size[0], 3, len(self.tokenizer), hidden_size, leaky_relu) + self.model.apply(init_weights) + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + return self.model.forward(images) + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + loss = self.forward_logits_loss(images, labels)[1] + self.log('loss', loss) + return loss diff --git a/strhub/models/modules.py b/strhub/models/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a89d05f6afd67437f3cfa8aff6d2d8b12df3fafa --- /dev/null +++ b/strhub/models/modules.py @@ -0,0 +1,20 @@ +r"""Shared modules used by CRNN and TRBA""" +from torch import nn + + +class BidirectionalLSTM(nn.Module): + """Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py""" + + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) + self.linear = nn.Linear(hidden_size * 2, output_size) + + def forward(self, input): + """ + input : visual feature [batch_size x T x input_size], T = num_steps. 
+ output : contextual feature [batch_size x T x output_size] + """ + recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) + output = self.linear(recurrent) # batch_size x T x output_size + return output diff --git a/strhub/models/parseq/__init__.py b/strhub/models/parseq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/strhub/models/parseq/__pycache__/__init__.cpython-37.pyc b/strhub/models/parseq/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e9ee673e7bb209273c806d5543799fe4db28c73 Binary files /dev/null and b/strhub/models/parseq/__pycache__/__init__.cpython-37.pyc differ diff --git a/strhub/models/parseq/__pycache__/modules.cpython-37.pyc b/strhub/models/parseq/__pycache__/modules.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9908253475ce05afd36ce1a366199055d46a6194 Binary files /dev/null and b/strhub/models/parseq/__pycache__/modules.cpython-37.pyc differ diff --git a/strhub/models/parseq/__pycache__/system.cpython-37.pyc b/strhub/models/parseq/__pycache__/system.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70e2abdffe375632e8db656a3e63e816b1b4d16b Binary files /dev/null and b/strhub/models/parseq/__pycache__/system.cpython-37.pyc differ diff --git a/strhub/models/parseq/modules.py b/strhub/models/parseq/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a7fdbd60bc5f05978448d4d00b73cb917bfea796 --- /dev/null +++ b/strhub/models/parseq/modules.py @@ -0,0 +1,126 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
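# Minimal shape sketch for BidirectionalLSTM above (illustrative sizes only):
import torch
from strhub.models.modules import BidirectionalLSTM

rnn = BidirectionalLSTM(input_size=512, hidden_size=256, output_size=256)
x = torch.randn(2, 26, 512)  # (batch_size, T, input_size)
print(rnn(x).shape)          # torch.Size([2, 26, 256]); the Linear maps 2*hidden_size -> output_size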
+ +import math +from typing import Optional + +import torch +from torch import nn as nn, Tensor +from torch.nn import functional as F +from torch.nn.modules import transformer + +from timm.models.vision_transformer import VisionTransformer, PatchEmbed + + +class DecoderLayer(nn.Module): + """A Transformer decoder layer supporting two-stream attention (XLNet) + This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch.""" + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', + layer_norm_eps=1e-5): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) + self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = transformer._get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.gelu + super().__setstate__(state) + + def forward_stream(self, tgt: Tensor, tgt_norm: Tensor, tgt_kv: Tensor, memory: Tensor, tgt_mask: Optional[Tensor], + tgt_key_padding_mask: Optional[Tensor]): + """Forward pass for a single stream (i.e. content or query) + tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency. + Both tgt_kv and memory are expected to be LayerNorm'd too. + memory is LayerNorm'd by ViT. 
+ """ + tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + tgt = tgt + self.dropout1(tgt2) + + tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory) + tgt = tgt + self.dropout2(tgt2) + + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt))))) + tgt = tgt + self.dropout3(tgt2) + return tgt, sa_weights, ca_weights + + def forward(self, query, content, memory, query_mask: Optional[Tensor] = None, content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None, update_content: bool = True): + query_norm = self.norm_q(query) + content_norm = self.norm_c(content) + query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0] + if update_content: + content = self.forward_stream(content, content_norm, content_norm, memory, content_mask, + content_key_padding_mask)[0] + return query, content + + +class Decoder(nn.Module): + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm): + super().__init__() + self.layers = transformer._get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, query, content, memory, query_mask: Optional[Tensor] = None, content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None): + for i, mod in enumerate(self.layers): + last = i == len(self.layers) - 1 + query, content = mod(query, content, memory, query_mask, content_mask, content_key_padding_mask, + update_content=not last) + query = self.norm(query) + return query + + +class Encoder(VisionTransformer): + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., + qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed): + super().__init__(img_size, patch_size, in_chans, embed_dim=embed_dim, depth=depth, num_heads=num_heads, + mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, embed_layer=embed_layer, + num_classes=0, global_pool='', class_token=False) # these disable the classifier head + + def forward(self, x): + # Return all tokens + return self.forward_features(x) + + +class TokenEmbedding(nn.Module): + + def __init__(self, charset_size: int, embed_dim: int): + super().__init__() + self.embedding = nn.Embedding(charset_size, embed_dim) + self.embed_dim = embed_dim + + def forward(self, tokens: torch.Tensor): + return math.sqrt(self.embed_dim) * self.embedding(tokens) diff --git a/strhub/models/parseq/system.py b/strhub/models/parseq/system.py new file mode 100644 index 0000000000000000000000000000000000000000..8d3e0be722819d0729d1c9074fb50ee910f457d6 --- /dev/null +++ b/strhub/models/parseq/system.py @@ -0,0 +1,259 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
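# Sketch of what Encoder.forward returns. The 32x128 image, 4x8 patch and 384-dim settings
# below are the values PARSeq configs typically use; they are assumptions here, not read
# from this file:
import torch
from strhub.models.parseq.modules import Encoder

enc = Encoder(img_size=(32, 128), patch_size=(4, 8), embed_dim=384, depth=12, num_heads=6)
memory = enc(torch.randn(2, 3, 32, 128))
print(memory.shape)  # torch.Size([2, 128, 384]): one token per 4x8 patch, no [CLS], no pooling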
+ +import math +from functools import partial +from itertools import permutations +from typing import Sequence, Any, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from pytorch_lightning.utilities.types import STEP_OUTPUT +from timm.models.helpers import named_apply + +from strhub.models.base import CrossEntropySystem +from strhub.models.utils import init_weights +from .modules import DecoderLayer, Decoder, Encoder, TokenEmbedding + + +class PARSeq(CrossEntropySystem): + + def __init__(self, charset_train: str, charset_test: str, max_label_length: int, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float, + img_size: Sequence[int], patch_size: Sequence[int], embed_dim: int, + enc_num_heads: int, enc_mlp_ratio: int, enc_depth: int, + dec_num_heads: int, dec_mlp_ratio: int, dec_depth: int, + perm_num: int, perm_forward: bool, perm_mirrored: bool, + decode_ar: bool, refine_iters: int, dropout: float, **kwargs: Any) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + + self.max_label_length = max_label_length + self.decode_ar = decode_ar + self.refine_iters = refine_iters + + self.encoder = Encoder(img_size, patch_size, embed_dim=embed_dim, depth=enc_depth, num_heads=enc_num_heads, + mlp_ratio=enc_mlp_ratio) + decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout) + self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=nn.LayerNorm(embed_dim)) + + # Perm/attn mask stuff + self.rng = np.random.default_rng() + self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num + self.perm_forward = perm_forward + self.perm_mirrored = perm_mirrored + + # We don't predict nor + self.head = nn.Linear(embed_dim, len(self.tokenizer) - 2) + self.text_embed = TokenEmbedding(len(self.tokenizer), embed_dim) + + # +1 for + self.pos_queries = nn.Parameter(torch.Tensor(1, max_label_length + 1, embed_dim)) + self.dropout = nn.Dropout(p=dropout) + # Encoder has its own init. + named_apply(partial(init_weights, exclude=['encoder']), self) + nn.init.trunc_normal_(self.pos_queries, std=.02) + + @torch.jit.ignore + def no_weight_decay(self): + param_names = {'text_embed.embedding.weight', 'pos_queries'} + enc_param_names = {'encoder.' + n for n in self.encoder.no_weight_decay()} + return param_names.union(enc_param_names) + + def encode(self, img: torch.Tensor): + return self.encoder(img) + + def decode(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[Tensor] = None, + tgt_padding_mask: Optional[Tensor] = None, tgt_query: Optional[Tensor] = None, + tgt_query_mask: Optional[Tensor] = None): + N, L = tgt.shape + # stands for the null context. We only supply position information for characters after . + null_ctx = self.text_embed(tgt[:, :1]) + tgt_emb = self.pos_queries[:, :L - 1] + self.text_embed(tgt[:, 1:]) + tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1)) + if tgt_query is None: + tgt_query = self.pos_queries[:, :L].expand(N, -1, -1) + tgt_query = self.dropout(tgt_query) + return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask) + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + testing = max_length is None + max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) + bs = images.shape[0] + # +1 for at end of sequence. 
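# (num_steps therefore covers max_length characters plus one trailing end-of-sequence token)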
+ num_steps = max_length + 1 + memory = self.encode(images) + + # Query positions up to `num_steps` + pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1) + + # Special case for the forward permutation. Faster than using `generate_attn_masks()` + tgt_mask = query_mask = torch.triu(torch.full((num_steps, num_steps), float('-inf'), device=self._device), 1) + + if self.decode_ar: + tgt_in = torch.full((bs, num_steps), self.pad_id, dtype=torch.long, device=self._device) + tgt_in[:, 0] = self.bos_id + + logits = [] + for i in range(num_steps): + j = i + 1 # next token index + # Efficient decoding: + # Input the context up to the ith token. We use only one query (at position = i) at a time. + # This works because of the lookahead masking effect of the canonical (forward) AR context. + # Past tokens have no access to future tokens, hence are fixed once computed. + tgt_out = self.decode(tgt_in[:, :j], memory, tgt_mask[:j, :j], tgt_query=pos_queries[:, i:j], + tgt_query_mask=query_mask[i:j, :j]) + # the next token probability is in the output's ith token position + p_i = self.head(tgt_out) + logits.append(p_i) + if j < num_steps: + # greedy decode. add the next token index to the target input + tgt_in[:, j] = p_i.squeeze().argmax(-1) + # Efficient batch decoding: If all output words have at least one EOS token, end decoding. + if testing and (tgt_in == self.eos_id).any(dim=-1).all(): + break + + logits = torch.cat(logits, dim=1) + else: + # No prior context, so input is just . We query all positions. + tgt_in = torch.full((bs, 1), self.bos_id, dtype=torch.long, device=self._device) + tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries) + logits = self.head(tgt_out) + + if self.refine_iters: + # For iterative refinement, we always use a 'cloze' mask. + # We can derive it from the AR forward mask by unmasking the token context to the right. + query_mask[torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool, device=self._device), 2)] = 0 + bos = torch.full((bs, 1), self.bos_id, dtype=torch.long, device=self._device) + for i in range(self.refine_iters): + # Prior context is the previous output. + tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1) + tgt_padding_mask = ((tgt_in == self.eos_id).int().cumsum(-1) > 0) # mask tokens beyond the first EOS token. + tgt_out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, + tgt_query=pos_queries, tgt_query_mask=query_mask[:, :tgt_in.shape[1]]) + logits = self.head(tgt_out) + + return logits + + def gen_tgt_perms(self, tgt): + """Generate shared permutations for the whole batch. + This works because the same attention mask can be used for the shorter sequences + because of the padding mask. + """ + # We don't permute the position of BOS, we permute EOS separately + max_num_chars = tgt.shape[1] - 2 + # Special handling for 1-character sequences + if max_num_chars == 1: + return torch.arange(3, device=self._device).unsqueeze(0) + perms = [torch.arange(max_num_chars, device=self._device)] if self.perm_forward else [] + # Additional permutations if needed + max_perms = math.factorial(max_num_chars) + if self.perm_mirrored: + max_perms //= 2 + num_gen_perms = min(self.max_gen_perms, max_perms) + # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions + # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars. + if max_num_chars < 5: + # Pool of permutations to sample from. 
We only need the first half (if complementary option is selected) + # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves + if max_num_chars == 4 and self.perm_mirrored: + selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21] + else: + selector = list(range(max_perms)) + perm_pool = torch.as_tensor(list(permutations(range(max_num_chars), max_num_chars)), device=self._device)[selector] + # If the forward permutation is always selected, no need to add it to the pool for sampling + if self.perm_forward: + perm_pool = perm_pool[1:] + perms = torch.stack(perms) + if len(perm_pool): + i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(perms), replace=False) + perms = torch.cat([perms, perm_pool[i]]) + else: + perms.extend([torch.randperm(max_num_chars, device=self._device) for _ in range(num_gen_perms - len(perms))]) + perms = torch.stack(perms) + if self.perm_mirrored: + # Add complementary pairs + comp = perms.flip(-1) + # Stack in such a way that the pairs are next to each other. + perms = torch.stack([perms, comp]).transpose(0, 1).reshape(-1, max_num_chars) + # NOTE: + # The only meaningful way of permuting the EOS position is by moving it one character position at a time. + # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS + # positions will always be much less than the number of permutations (unless a low perm_num is set). + # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly + # distribute it across the chosen number of permutations. + # Add position indices of BOS and EOS + bos_idx = perms.new_zeros((len(perms), 1)) + eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1) + perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1) + # Special handling for the reverse direction. This does two things: + # 1. Reverse context for the characters + # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode) + if len(perms) > 1: + perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=self._device) + return perms + + def generate_attn_masks(self, perm): + """Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens) + :param perm: the permutation sequence. i = 0 is always the BOS + :return: lookahead attention masks + """ + sz = perm.shape[0] + mask = torch.zeros((sz, sz), device=self._device) + for i in range(sz): + query_idx = perm[i] + masked_keys = perm[i + 1:] + mask[query_idx, masked_keys] = float('-inf') + content_mask = mask[:-1, :-1].clone() + mask[torch.eye(sz, dtype=torch.bool, device=self._device)] = float('-inf') # mask "self" + query_mask = mask[1:, :-1] + return content_mask, query_mask + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + tgt = self.tokenizer.encode(labels, self._device) + + # Encode the source sequence (i.e. 
the image codes) + memory = self.encode(images) + + # Prepare the target sequences (input and output) + tgt_perms = self.gen_tgt_perms(tgt) + tgt_in = tgt[:, :-1] + tgt_out = tgt[:, 1:] + # The [EOS] token is not depended upon by any other token in any permutation ordering + tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id) + + loss = 0 + loss_numel = 0 + n = (tgt_out != self.pad_id).sum().item() + for i, perm in enumerate(tgt_perms): + tgt_mask, query_mask = self.generate_attn_masks(perm) + out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask) + logits = self.head(out).flatten(end_dim=1) + loss += n * F.cross_entropy(logits, tgt_out.flatten(), ignore_index=self.pad_id) + loss_numel += n + # After the second iteration (i.e. done with canonical and reverse orderings), + # remove the [EOS] tokens for the succeeding perms + if i == 1: + tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id, tgt_out) + n = (tgt_out != self.pad_id).sum().item() + loss /= loss_numel + + self.log('loss', loss) + return loss diff --git a/strhub/models/trba/__init__.py b/strhub/models/trba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a574a8af95e7f1ffaa05c45b4cd22f4a3cc0a5c0 --- /dev/null +++ b/strhub/models/trba/__init__.py @@ -0,0 +1,13 @@ +r""" +Baek, Jeonghun, Geewook Kim, Junyeop Lee, Sungrae Park, Dongyoon Han, Sangdoo Yun, Seong Joon Oh, and Hwalsuk Lee. +"What is wrong with scene text recognition model comparisons? dataset and model analysis." +In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4715-4723. 2019. + +https://arxiv.org/abs/1904.01906 + +All source files, except `system.py`, are based on the implementation listed below, +and hence are released under the license of the original. 
+ +Source: https://github.com/clovaai/deep-text-recognition-benchmark +License: Apache License 2.0 (see LICENSE file in project root) +""" diff --git a/strhub/models/trba/feature_extraction.py b/strhub/models/trba/feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..17646e3ff83ad28c1021237824a838e38c3b6345 --- /dev/null +++ b/strhub/models/trba/feature_extraction.py @@ -0,0 +1,110 @@ +import torch.nn as nn + +from torchvision.models.resnet import BasicBlock + + +class ResNet_FeatureExtractor(nn.Module): + """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super().__init__() + self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) + + def forward(self, input): + return self.ConvNet(input) + + +class ResNet(nn.Module): + + def __init__(self, input_channel, output_channel, block, layers): + super().__init__() + + self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] + + self.inplanes = int(output_channel / 8) + self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) + self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_2 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU(inplace=True) + + self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) + self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ + 0], kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) + + self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) + self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ + 1], kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) + + self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) + self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) + self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ + 2], kernel_size=3, stride=1, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) + + self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) + self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) + self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) + self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=1, padding=0, bias=False) + self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * 
block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv0_1(x) + x = self.bn0_1(x) + x = self.relu(x) + x = self.conv0_2(x) + x = self.bn0_2(x) + x = self.relu(x) + + x = self.maxpool1(x) + x = self.layer1(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.maxpool2(x) + x = self.layer2(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + x = self.maxpool3(x) + x = self.layer3(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + + x = self.layer4(x) + x = self.conv4_1(x) + x = self.bn4_1(x) + x = self.relu(x) + x = self.conv4_2(x) + x = self.bn4_2(x) + x = self.relu(x) + + return x diff --git a/strhub/models/trba/model.py b/strhub/models/trba/model.py new file mode 100644 index 0000000000000000000000000000000000000000..41161a4df4e2ff368bfe1c62f681c6964510a0c0 --- /dev/null +++ b/strhub/models/trba/model.py @@ -0,0 +1,55 @@ +import torch.nn as nn + +from strhub.models.modules import BidirectionalLSTM +from .feature_extraction import ResNet_FeatureExtractor +from .prediction import Attention +from .transformation import TPS_SpatialTransformerNetwork + + +class TRBA(nn.Module): + + def __init__(self, img_h, img_w, num_class, num_fiducial=20, input_channel=3, output_channel=512, hidden_size=256, + use_ctc=False): + super().__init__() + """ Transformation """ + self.Transformation = TPS_SpatialTransformerNetwork( + F=num_fiducial, I_size=(img_h, img_w), I_r_size=(img_h, img_w), + I_channel_num=input_channel) + + """ FeatureExtraction """ + self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel) + self.FeatureExtraction_output = output_channel + self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 + + """ Sequence modeling""" + self.SequenceModeling = nn.Sequential( + BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size), + BidirectionalLSTM(hidden_size, hidden_size, hidden_size)) + self.SequenceModeling_output = hidden_size + + """ Prediction """ + if use_ctc: + self.Prediction = nn.Linear(self.SequenceModeling_output, num_class) + else: + self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class) + + def forward(self, image, max_label_length, text=None): + """ Transformation stage """ + image = self.Transformation(image) + + """ Feature extraction stage """ + visual_feature = self.FeatureExtraction(image) + visual_feature = visual_feature.permute(0, 3, 1, 2) # [b, c, h, w] -> [b, w, c, h] + visual_feature = self.AdaptiveAvgPool(visual_feature) # [b, w, c, h] -> [b, w, c, 1] + visual_feature = visual_feature.squeeze(3) # [b, w, c, 1] -> [b, w, c] + + """ Sequence modeling stage """ + contextual_feature = self.SequenceModeling(visual_feature) # [b, num_steps, hidden_size] + + """ Prediction stage """ + if isinstance(self.Prediction, Attention): + prediction = self.Prediction(contextual_feature.contiguous(), text, max_label_length) + else: + prediction = self.Prediction(contextual_feature.contiguous()) # CTC + + return prediction # [b, num_steps, num_class] diff --git a/strhub/models/trba/prediction.py b/strhub/models/trba/prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..5609398a28ef5288d3f3971786c2cebc2e574336 --- /dev/null +++ b/strhub/models/trba/prediction.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Attention(nn.Module): + + def __init__(self, 
input_size, hidden_size, num_class, num_char_embeddings=256): + super().__init__() + self.attention_cell = AttentionCell(input_size, hidden_size, num_char_embeddings) + self.hidden_size = hidden_size + self.num_class = num_class + self.generator = nn.Linear(hidden_size, num_class) + self.char_embeddings = nn.Embedding(num_class, num_char_embeddings) + + def forward(self, batch_H, text, max_label_length=25): + """ + input: + batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_class] + text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [SOS] token. text[:, 0] = [SOS]. + output: probability distribution at each step [batch_size x num_steps x num_class] + """ + batch_size = batch_H.size(0) + num_steps = max_label_length + 1 # +1 for [EOS] at end of sentence. + + output_hiddens = batch_H.new_zeros((batch_size, num_steps, self.hidden_size), dtype=torch.float) + hidden = (batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float), + batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float)) + + if self.training: + for i in range(num_steps): + char_embeddings = self.char_embeddings(text[:, i]) + # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_embeddings : f(y_{t-1}) + hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) + output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) + probs = self.generator(output_hiddens) + + else: + targets = text[0].expand(batch_size) # should be fill with [SOS] token + probs = batch_H.new_zeros((batch_size, num_steps, self.num_class), dtype=torch.float) + + for i in range(num_steps): + char_embeddings = self.char_embeddings(targets) + hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) + probs_step = self.generator(hidden[0]) + probs[:, i, :] = probs_step + _, next_input = probs_step.max(1) + targets = next_input + + return probs # batch_size x num_steps x num_class + + +class AttentionCell(nn.Module): + + def __init__(self, input_size, hidden_size, num_embeddings): + super().__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias=False) + self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias + self.score = nn.Linear(hidden_size, 1, bias=False) + self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_embeddings): + # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) + e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 + + alpha = F.softmax(e, dim=1) + context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel + concat_context = torch.cat([context, char_embeddings], 1) # batch_size x (num_channel + num_embedding) + cur_hidden = self.rnn(concat_context, prev_hidden) + return cur_hidden, alpha diff --git a/strhub/models/trba/system.py b/strhub/models/trba/system.py new file mode 100644 index 0000000000000000000000000000000000000000..31bbb6d44e6eabf47402ae998ffbe4de8fc427a2 --- /dev/null +++ b/strhub/models/trba/system.py @@ -0,0 +1,87 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Sequence, Any, Optional
+
+import torch
+import torch.nn.functional as F
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from timm.models.helpers import named_apply
+from torch import Tensor
+
+from strhub.models.base import CrossEntropySystem, CTCSystem
+from strhub.models.utils import init_weights
+from .model import TRBA as Model
+
+
+class TRBA(CrossEntropySystem):
+
+    def __init__(self, charset_train: str, charset_test: str, max_label_length: int,
+                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float,
+                 img_size: Sequence[int], num_fiducial: int, output_channel: int, hidden_size: int,
+                 **kwargs: Any) -> None:
+        super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
+        self.save_hyperparameters()
+        self.max_label_length = max_label_length
+        img_h, img_w = img_size
+        self.model = Model(img_h, img_w, len(self.tokenizer), num_fiducial,
+                           output_channel=output_channel, hidden_size=hidden_size, use_ctc=False)
+        named_apply(partial(init_weights, exclude=['Transformation.LocalizationNetwork.localization_fc2']), self.model)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'model.Prediction.char_embeddings.weight'}
+
+    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
+        max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
+        text = images.new_full([1], self.bos_id, dtype=torch.long)
+        return self.model.forward(images, max_length, text)
+
+    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+        images, labels = batch
+        encoded = self.tokenizer.encode(labels, self.device)
+        inputs = encoded[:, :-1]  # remove <eos>
+        targets = encoded[:, 1:]  # remove <bos>
+        max_length = encoded.shape[1] - 2  # exclude <bos> and <eos> from count
+        logits = self.model.forward(images, max_length, inputs)
+        loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id)
+        self.log('loss', loss)
+        return loss
+
+
+class TRBC(CTCSystem):
+
+    def __init__(self, charset_train: str, charset_test: str, max_label_length: int,
+                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float,
+                 img_size: Sequence[int], num_fiducial: int, output_channel: int, hidden_size: int,
+                 **kwargs: Any) -> None:
+        super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
+        self.save_hyperparameters()
+        self.max_label_length = max_label_length
+        img_h, img_w = img_size
+        self.model = Model(img_h, img_w, len(self.tokenizer), num_fiducial,
+                           output_channel=output_channel, hidden_size=hidden_size, use_ctc=True)
+        named_apply(partial(init_weights, exclude=['Transformation.LocalizationNetwork.localization_fc2']), self.model)
+
+    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
+        # max_label_length is unused in CTC prediction
+        return self.model.forward(images, None)
+
+    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+        images, labels = batch
+        loss = self.forward_logits_loss(images, labels)[1]
+        self.log('loss', loss)
+        return loss
diff --git 
a/strhub/models/trba/transformation.py b/strhub/models/trba/transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..960419d135ec878aaaa3297c3ff5c22e998ef6be --- /dev/null +++ b/strhub/models/trba/transformation.py @@ -0,0 +1,169 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TPS_SpatialTransformerNetwork(nn.Module): + """ Rectification Network of RARE, namely TPS based STN """ + + def __init__(self, F, I_size, I_r_size, I_channel_num=1): + """ Based on RARE TPS + input: + batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] + I_size : (height, width) of the input image I + I_r_size : (height, width) of the rectified image I_r + I_channel_num : the number of channels of the input image I + output: + batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] + """ + super().__init__() + self.F = F + self.I_size = I_size + self.I_r_size = I_r_size # = (I_r_height, I_r_width) + self.I_channel_num = I_channel_num + self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) + self.GridGenerator = GridGenerator(self.F, self.I_r_size) + + def forward(self, batch_I): + batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 + # batch_size x n (= I_r_width x I_r_height) x 2 + build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) + build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) + + if torch.__version__ > "1.2.0": + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) + else: + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') + + return batch_I_r + + +class LocalizationNetwork(nn.Module): + """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ + + def __init__(self, F, I_channel_num): + super().__init__() + self.F = F + self.I_channel_num = I_channel_num + self.conv = nn.Sequential( + nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, + bias=False), nn.BatchNorm2d(64), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 + nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 + nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 + nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), + nn.AdaptiveAvgPool2d(1) # batch_size x 512 + ) + + self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) + self.localization_fc2 = nn.Linear(256, self.F * 2) + + # Init fc2 in LocalizationNetwork + self.localization_fc2.weight.data.fill_(0) + """ see RARE paper Fig. 
6 (a) """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) + ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) + + def forward(self, batch_I): + """ + input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] + output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] + """ + batch_size = batch_I.size(0) + features = self.conv(batch_I).view(batch_size, -1) + batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) + return batch_C_prime + + +class GridGenerator(nn.Module): + """ Grid Generator of RARE, which produces P_prime by multipling T with P """ + + def __init__(self, F, I_r_size): + """ Generate P_hat and inv_delta_C for later """ + super().__init__() + self.eps = 1e-6 + self.I_r_height, self.I_r_width = I_r_size + self.F = F + self.C = self._build_C(self.F) # F x 2 + self.P = self._build_P(self.I_r_width, self.I_r_height) + + # num_gpu = torch.cuda.device_count() + # if num_gpu > 1: + # for multi-gpu, you may need register buffer + self.register_buffer("inv_delta_C", torch.tensor( + self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 + self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 + # else: + # # for fine-tuning with different image width, you may use below instead of self.register_buffer + # self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float() # F+3 x F+3 + # self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float() # n x F+3 + + def _build_C(self, F): + """ Return coordinates of fiducial points in I_r; C """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = -1 * np.ones(int(F / 2)) + ctrl_pts_y_bottom = np.ones(int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + return C # F x 2 + + def _build_inv_delta_C(self, F, C): + """ Return inv_delta_C which is needed to calculate T """ + hat_C = np.zeros((F, F), dtype=float) # F x F + for i in range(0, F): + for j in range(i, F): + r = np.linalg.norm(C[i] - C[j]) + hat_C[i, j] = r + hat_C[j, i] = r + np.fill_diagonal(hat_C, 1) + hat_C = (hat_C ** 2) * np.log(hat_C) + # print(C.shape, hat_C.shape) + delta_C = np.concatenate( # F+3 x F+3 + [ + np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 + np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 + np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 + ], + axis=0 + ) + inv_delta_C = np.linalg.inv(delta_C) + return inv_delta_C # F+3 x F+3 + + def _build_P(self, I_r_width, I_r_height): + I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width + I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height + P = np.stack( # self.I_r_width x self.I_r_height x 2 + np.meshgrid(I_r_grid_x, I_r_grid_y), + axis=2 + ) + return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 + + def _build_P_hat(self, F, C, P): + n = P.shape[0] # n 
(= self.I_r_width x self.I_r_height) + P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 + C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 + P_diff = P_tile - C_tile # n x F x 2 + rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F + rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F + P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) + return P_hat # n x F+3 + + def build_P_prime(self, batch_C_prime): + """ Generate Grid from batch_C_prime [batch_size x F x 2] """ + batch_size = batch_C_prime.size(0) + batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) + batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) + batch_C_prime_with_zeros = torch.cat((batch_C_prime, batch_C_prime.new_zeros( + (batch_size, 3, 2), dtype=torch.float)), dim=1) # batch_size x F+3 x 2 + batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 + batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 + return batch_P_prime # batch_size x n x 2 diff --git a/strhub/models/utils.py b/strhub/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f53ddca31f10366840d00b7f82a44471dce089df --- /dev/null +++ b/strhub/models/utils.py @@ -0,0 +1,123 @@ +from pathlib import PurePath +from typing import Sequence + +import torch +from torch import nn + +import yaml + + +class InvalidModelError(RuntimeError): + """Exception raised for any model-related error (creation, loading)""" + + +_WEIGHTS_URL = { + 'parseq-tiny': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq_tiny-e7a21b54.pt', + 'parseq': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq-bb5792a6.pt', + 'abinet': 'https://github.com/baudm/parseq/releases/download/v1.0.0/abinet-1d1e373e.pt', + 'trba': 'https://github.com/baudm/parseq/releases/download/v1.0.0/trba-cfaed284.pt', + 'vitstr': 'https://github.com/baudm/parseq/releases/download/v1.0.0/vitstr-26d0fcf4.pt', + 'crnn': 'https://github.com/baudm/parseq/releases/download/v1.0.0/crnn-679d0e31.pt', +} + + +def _get_config(experiment: str, **kwargs): + """Emulates hydra config resolution""" + root = PurePath(__file__).parents[2] + with open(root / 'configs/main.yaml', 'r') as f: + config = yaml.load(f, yaml.Loader)['model'] + with open(root / f'configs/charset/94_full.yaml', 'r') as f: + config.update(yaml.load(f, yaml.Loader)['model']) + with open(root / f'configs/experiment/{experiment}.yaml', 'r') as f: + exp = yaml.load(f, yaml.Loader) + # Apply base model config + model = exp['defaults'][0]['override /model'] + with open(root / f'configs/model/{model}.yaml', 'r') as f: + config.update(yaml.load(f, yaml.Loader)) + # Apply experiment config + if 'model' in exp: + config.update(exp['model']) + config.update(kwargs) + # Workaround for now: manually cast the lr to the correct type. 
+ config['lr'] = float(config['lr']) + return config + + +def _get_model_class(key): + if 'abinet' in key: + from .abinet.system import ABINet as ModelClass + elif 'crnn' in key: + from .crnn.system import CRNN as ModelClass + elif 'parseq' in key: + from .parseq.system import PARSeq as ModelClass + elif 'trba' in key: + from .trba.system import TRBA as ModelClass + elif 'trbc' in key: + from .trba.system import TRBC as ModelClass + elif 'vitstr' in key: + from .vitstr.system import ViTSTR as ModelClass + else: + raise InvalidModelError("Unable to find model class for '{}'".format(key)) + return ModelClass + + +def get_pretrained_weights(experiment): + try: + url = _WEIGHTS_URL[experiment] + except KeyError: + raise InvalidModelError("No pretrained weights found for '{}'".format(experiment)) from None + return torch.hub.load_state_dict_from_url(url=url, map_location='cpu', check_hash=True) + + +def create_model(experiment: str, pretrained: bool = False, **kwargs): + try: + config = _get_config(experiment, **kwargs) + except FileNotFoundError: + raise InvalidModelError("No configuration found for '{}'".format(experiment)) from None + ModelClass = _get_model_class(experiment) + model = ModelClass(**config) + if pretrained: + model.load_state_dict(get_pretrained_weights(experiment)) + return model + + +def load_from_checkpoint(checkpoint_path: str, **kwargs): + if checkpoint_path.startswith('pretrained='): + model_id = checkpoint_path.split('=', maxsplit=1)[1] + model = create_model(model_id, True, **kwargs) + else: + ModelClass = _get_model_class(checkpoint_path) + model = ModelClass.load_from_checkpoint(checkpoint_path, **kwargs) + return model + + +def parse_model_args(args): + kwargs = {} + arg_types = {t.__name__: t for t in [int, float, str]} + arg_types['bool'] = lambda v: v.lower() == 'true' # special handling for bool + for arg in args: + name, value = arg.split('=', maxsplit=1) + name, arg_type = name.split(':', maxsplit=1) + kwargs[name] = arg_types[arg_type](value) + return kwargs + + +def init_weights(module: nn.Module, name: str = '', exclude: Sequence[str] = ()): + """Initialize the weights using the typical initialization schemes used in SOTA models.""" + if any(map(name.startswith, exclude)): + return + if isinstance(module, nn.Linear): + nn.init.trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.trunc_normal_(module.weight, std=.02) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) diff --git a/strhub/models/vitstr/.ipynb_checkpoints/system-checkpoint.py b/strhub/models/vitstr/.ipynb_checkpoints/system-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..f5cedc4d2e1ef08430df743c42150c5cc84220dc --- /dev/null +++ b/strhub/models/vitstr/.ipynb_checkpoints/system-checkpoint.py @@ -0,0 +1,58 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence, Any, Optional + +import torch +from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch import Tensor + +from strhub.models.base import CrossEntropySystem +from strhub.models.utils import init_weights +from .model import ViTSTR as Model + + +class ViTSTR(CrossEntropySystem): + + def __init__(self, charset_train: str, charset_test: str, max_label_length: int, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float, + img_size: Sequence[int], patch_size: Sequence[int], embed_dim: int, num_heads: int, + **kwargs: Any) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + self.max_label_length = max_label_length + # We don't predict nor + self.model = Model(img_size=img_size, patch_size=patch_size, depth=12, mlp_ratio=4, qkv_bias=True, + embed_dim=embed_dim, num_heads=num_heads, num_classes=len(self.tokenizer) - 2) + # Non-zero weight init for the head + self.model.head.apply(init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {'model.' + n for n in self.model.no_weight_decay()} + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) + logits = self.model.forward(images, max_length + 2) # +2 tokens for [GO] and [s] + # Truncate to conform to other models. [GO] in ViTSTR is actually used as the padding (therefore, ignored). + # First position corresponds to the class token, which is unused and ignored in the original work. + logits = logits[:, 1:] + return logits + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + loss = self.forward_logits_loss(images, labels)[1] + self.log('loss', loss) + return loss diff --git a/strhub/models/vitstr/__init__.py b/strhub/models/vitstr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..19e985679da1fcaa6deb306697993fd601892d6c --- /dev/null +++ b/strhub/models/vitstr/__init__.py @@ -0,0 +1,12 @@ +r""" +Atienza, Rowel. "Vision Transformer for Fast and Efficient Scene Text Recognition." +In International Conference on Document Analysis and Recognition (ICDAR). 2021. + +https://arxiv.org/abs/2105.08582 + +All source files, except `system.py`, are based on the implementation listed below, +and hence are released under the license of the original. + +Source: https://github.com/roatienza/deep-text-recognition-benchmark +License: Apache License 2.0 (see LICENSE file in project root) +""" diff --git a/strhub/models/vitstr/model.py b/strhub/models/vitstr/model.py new file mode 100644 index 0000000000000000000000000000000000000000..62c5d551626c325243a4f0d055869384a59b3910 --- /dev/null +++ b/strhub/models/vitstr/model.py @@ -0,0 +1,28 @@ +""" +Implementation of ViTSTR based on timm VisionTransformer. 
+
+TODO:
+1) distilled deit backbone
+2) base deit backbone
+
+Copyright 2021 Rowel Atienza
+"""
+
+from timm.models.vision_transformer import VisionTransformer
+
+
+class ViTSTR(VisionTransformer):
+    """
+    ViTSTR is basically a ViT that uses DeiT weights.
+    Modified head to support a sequence of characters prediction for STR.
+    """
+
+    def forward(self, x, seqlen: int = 25):
+        x = self.forward_features(x)
+        x = x[:, :seqlen]
+
+        # batch, seqlen, embsize
+        b, s, e = x.size()
+        x = x.reshape(b * s, e)
+        x = self.head(x).view(b, s, self.num_classes)
+        return x
diff --git a/strhub/models/vitstr/system.py b/strhub/models/vitstr/system.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5cedc4d2e1ef08430df743c42150c5cc84220dc
--- /dev/null
+++ b/strhub/models/vitstr/system.py
@@ -0,0 +1,58 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Sequence, Any, Optional
+
+import torch
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from torch import Tensor
+
+from strhub.models.base import CrossEntropySystem
+from strhub.models.utils import init_weights
+from .model import ViTSTR as Model
+
+
+class ViTSTR(CrossEntropySystem):
+
+    def __init__(self, charset_train: str, charset_test: str, max_label_length: int,
+                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float,
+                 img_size: Sequence[int], patch_size: Sequence[int], embed_dim: int, num_heads: int,
+                 **kwargs: Any) -> None:
+        super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
+        self.save_hyperparameters()
+        self.max_label_length = max_label_length
+        # We don't predict <bos> nor <pad>
+        self.model = Model(img_size=img_size, patch_size=patch_size, depth=12, mlp_ratio=4, qkv_bias=True,
+                           embed_dim=embed_dim, num_heads=num_heads, num_classes=len(self.tokenizer) - 2)
+        # Non-zero weight init for the head
+        self.model.head.apply(init_weights)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'model.' + n for n in self.model.no_weight_decay()}
+
+    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
+        max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
+        logits = self.model.forward(images, max_length + 2)  # +2 tokens for [GO] and [s]
+        # Truncate to conform to other models. [GO] in ViTSTR is actually used as the padding (therefore, ignored).
+        # First position corresponds to the class token, which is unused and ignored in the original work.
+        logits = logits[:, 1:]
+        return logits
+
+    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+        images, labels = batch
+        loss = self.forward_logits_loss(images, labels)[1]
+        self.log('loss', loss)
+        return loss
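
Usage note (not part of the diff): the TRBA, TRBC and ViTSTR systems above are normally constructed through the factory helpers added in strhub/models/utils.py rather than instantiated directly. The minimal sketch below illustrates that flow; it assumes the configs/ tree read by _get_config() and the release URLs in _WEIGHTS_URL are reachable, and the random tensor stands in for a preprocessed word-crop batch purely to show the expected shapes.

import torch

from strhub.models.utils import load_from_checkpoint, parse_model_args

# CLI-style overrides use the name:type=value form handled by parse_model_args.
overrides = parse_model_args(['batch_size:int=1'])

# 'pretrained=trba' resolves through _WEIGHTS_URL via create_model();
# a plain checkpoint path would instead go through ModelClass.load_from_checkpoint().
model = load_from_checkpoint('pretrained=trba', **overrides).eval()

h, w = model.hparams.img_size        # stored by save_hyperparameters() in __init__
dummy = torch.rand(1, 3, h, w)       # placeholder image batch (illustrative only)
with torch.no_grad():
    probs = model(dummy).softmax(-1)  # [batch, num_steps, num_classes]
labels, confidences = model.tokenizer.decode(probs)
print(labels[0])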