diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..da01597b0e35d851054d9d4b780706a1afc11ab5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,176 @@ +# Output directories +outputs/ +multirun/ +ray_results/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +# requirements/core.*.txt +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +.python-version + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# IDE +.idea/ + +########## CUSTOM FOLDER ############## +README_original.md + +results/ +images +bharatSTR/East/tmp +bharatSTR/models +bharatSTR/images +__pycache__/ +bharatSTR/ + +IndicPhotoOCR/detection/East +IndicPhotoOCR/recognition/models + +IndicPhotoOCR/script_identification/images +IndicPhotoOCR/script_identification/models + + +build/ +dist/ +test.png +static/pics/IndicPhotoOCR.gif +input_image.jpg +output_image.png + + + diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..aaa81c57c8c05ec7968a7d6871f4cdeb75401105 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,45 @@ +# Use NVIDIA PyTorch as the base image +FROM nvcr.io/nvidia/pytorch:23.12-py3 + +# Install additional dependencies +RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 + +# Set environment variables for Miniconda and Conda environment +ENV CONDA_DIR /opt/conda +ENV PATH $CONDA_DIR/bin:$PATH + +# Install Miniconda +RUN apt-get update && apt-get install -y wget && \ + wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + bash Miniconda3-latest-Linux-x86_64.sh -b -p $CONDA_DIR && \ + rm Miniconda3-latest-Linux-x86_64.sh + +# Create a new Conda environment named "bocr" with Python 3.9 +RUN conda create -n bocr python=3.9 -y + +# Initialize conda +RUN conda init + +# Reload the env configs +RUN source ~/.bashrc + +# Make RUN commands use the bocr environment +SHELL ["conda", "run", "-n", "bocr", "/bin/bash", "-c"] + +# # Set default shell to bash +# SHELL ["/bin/bash", "-c"] + +# # Clone BharatOCR repository +# RUN git clone https://github.com/Bhashini-IITJ/BharatOCR.git && \ +# git switch photoOCR && \ +# cd IndicPhotoOCR && \ +# python setup.py sdist bdist_wheel && \ +# pip install ./dist/IndicPhotoOCR-1.1.0-py3-none-any.whl[cu118] --extra-index-url https://download.pytorch.org/whl/cu118 + +# # # Set default command to run BharatOCR +# CMD ["conda", "run", "-n", "bocr", "python", "-m", "IndicPhotoOCR.ocr"] + + +# cd IndicPhotoOCR +# sudo docker build -t indicphotoocr:latest . 
+# sudo docker run --gpus all --rm -it indicphotoocr:latest
\ No newline at end of file
diff --git a/IndicPhotoOCR/__init__.py b/IndicPhotoOCR/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/IndicPhotoOCR/detection/__init__.py b/IndicPhotoOCR/detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/IndicPhotoOCR/detection/east_config.py b/IndicPhotoOCR/detection/east_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6c157236b45de4c590d5c6e7cc147a2776f7af6
--- /dev/null
+++ b/IndicPhotoOCR/detection/east_config.py
@@ -0,0 +1,39 @@
+
+# data-config
+import numpy as np
+
+train_data_path = './dataset/train/'
+train_batch_size_per_gpu = 14  # 14
+num_workers = 24  # 24
+gpu_ids = [0]  # [0,1,2,3]
+gpu = 1  # 4
+input_size = 512  # image size after preprocessing and normalization
+background_ratio = 3. / 8  # fraction of pure-background crops
+random_scale = np.array([0.5, 1, 2.0, 3.0])  # scales used for multi-scale augmentation
+geometry = 'RBOX'  # type of geometry map to use
+max_image_large_side = 1280
+max_text_size = 800
+min_text_size = 10
+min_crop_side_ratio = 0.1
+means = [100, 100, 100]
+pretrained = True  # whether to load pretrained weights for the backbone
+pretrained_basemodel_path = 'IndicPhotoOCR/detection/East/tmp/backbone_net/mobilenet_v2.pth.tar'
+pre_lr = 1e-4  # initial learning rate for the backbone
+lr = 1e-3  # initial learning rate for the remaining layers
+decay_steps = 50  # decayed_learning_rate = learning_rate * decay_rate ^ (global_epoch / decay_steps)
+decay_rate = 0.97
+init_type = 'xavier'  # weight initialization scheme
+resume = True  # whether to resume the whole network from a saved checkpoint
+checkpoint = 'IndicPhotoOCR/detection/East/tmp/epoch_990_checkpoint.pth.tar'  # exact checkpoint path and file name
+max_epochs = 1000  # maximum number of training epochs
+l2_weight_decay = 1e-6  # weight of the L2 regularization penalty
+print_freq = 10  # print the loss every 10 batches
+save_eval_iteration = 50  # save the model and run an evaluation at this epoch interval
+save_model_path = './tmp/'  # directory where models are saved
+test_img_path = './dataset/full_set'  # demo test images: './demo/test_img/'; dataset evaluation: './dataset/test/'
+res_img_path = 'results'  # demo results: './demo/result_img/'; dataset evaluation: './dataset/test_result/'
+write_images = True  # whether to write image results
+score_map_thresh = 0.8  # confidence threshold for the score map
+box_thresh = 0.1  # threshold on the mean confidence inside a text box
+nms_thres = 0.2  # IoU threshold for locality-aware NMS
+compute_hmean_path = './dataset/test_compute_hmean/'
diff --git a/IndicPhotoOCR/detection/east_detector.py b/IndicPhotoOCR/detection/east_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..33478de321e23bdb21c46c1c6cdc8a5ae426778d
--- /dev/null
+++ b/IndicPhotoOCR/detection/east_detector.py
@@ -0,0 +1,87 @@
+import os
+import torch
+import cv2
+import numpy as np
+import time
+import warnings
+
+
+import IndicPhotoOCR.detection.east_config as cfg
+from IndicPhotoOCR.detection.east_utils import ModelManager
+from IndicPhotoOCR.detection.east_model import East
+import IndicPhotoOCR.detection.east_utils as utils
+
+# Suppress warnings
+warnings.filterwarnings("ignore")
+
+class EASTdetector:
+    def __init__(self, model_name="east", model_path=None):
+        self.model_path = model_path
+        # self.model_manager = ModelManager()
+        # self.model_manager.ensure_model(model_name)
+        # self.ensure_model(self.model_name)
+        # self.root_model_dir = "BharatSTR/bharatOCR/detection/East/tmp"
+
+    def detect(self, image_path, model_checkpoint, device):
+        # Load image
+        im = cv2.imread(image_path)
+        # im = cv2.imread(image_path)[:, :, ::-1]
+
+        # Initialize the EAST model and load checkpoint
+        model = East()
+        model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
+
+        # Load the model checkpoint with weights_only=True
+        checkpoint =
torch.load(model_checkpoint, map_location=torch.device(device), weights_only=True) + model.load_state_dict(checkpoint['state_dict']) + model.eval() + + # Resize image and convert to tensor format + im_resized, (ratio_h, ratio_w) = utils.resize_image(im) + im_resized = im_resized.astype(np.float32).transpose(2, 0, 1) + im_tensor = torch.from_numpy(im_resized).unsqueeze(0).cpu() + + # Inference + timer = {'net': 0, 'restore': 0, 'nms': 0} + start = time.time() + score, geometry = model(im_tensor) + timer['net'] = time.time() - start + + # Process output + score = score.permute(0, 2, 3, 1).data.cpu().numpy() + geometry = geometry.permute(0, 2, 3, 1).data.cpu().numpy() + + # Detect boxes + boxes, timer = utils.detect( + score_map=score, geo_map=geometry, timer=timer, + score_map_thresh=cfg.score_map_thresh, box_thresh=cfg.box_thresh, + nms_thres=cfg.box_thresh + ) + bbox_result_dict = {'detections': []} + + # Parse detected boxes and adjust coordinates + if boxes is not None: + boxes = boxes[:, :8].reshape((-1, 4, 2)) + boxes[:, :, 0] /= ratio_w + boxes[:, :, 1] /= ratio_h + for box in boxes: + box = utils.sort_poly(box.astype(np.int32)) + if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5: + continue + bbox_result_dict['detections'].append([ + [int(coord[0]), int(coord[1])] for coord in box + ]) + + return bbox_result_dict + +# if __name__ == "__main__": +# import argparse +# parser = argparse.ArgumentParser(description='Text detection using EAST model') +# parser.add_argument('--image_path', type=str, required=True, help='Path to the input image') +# parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"') +# parser.add_argument('--model_checkpoint', type=str, required=True, help='Path to the model checkpoint file') +# args = parser.parse_args() + +# # Run prediction and get results as dictionary +# detection_result = predict(args.image_path, args.device, args.model_checkpoint) +# print(detection_result) diff --git a/IndicPhotoOCR/detection/east_locality_aware_nms.py b/IndicPhotoOCR/detection/east_locality_aware_nms.py new file mode 100644 index 0000000000000000000000000000000000000000..63eea757d4dde10d6dc2ab216c150fd5f7cdaec2 --- /dev/null +++ b/IndicPhotoOCR/detection/east_locality_aware_nms.py @@ -0,0 +1,75 @@ + +import numpy as np +from shapely.geometry import Polygon + + +def intersection(g, p): + g = Polygon(g[:8].reshape((4, 2))) + p = Polygon(p[:8].reshape((4, 2))) + if not g.is_valid or not p.is_valid: + return 0 + inter = Polygon(g).intersection(Polygon(p)).area + union = g.area + p.area - inter + if union == 0: + return 0 + else: + return inter/union + + +def weighted_merge(g, p): + # g[0]=min(g[0],p[0]) + # g[1] = min(g[1], p[1]) + # g[4] = max(g[4], p[4]) + # g[5]= max(g[5],p[5]) + # + # g[2] = max(g[2], p[2]) + # g[3] = min(g[3], p[3]) + # g[6] = min(g[6], p[6]) + # g[7] = max(g[7], p[7]) + + g[:8] = (g[8] * g[:8] + p[8] * p[:8])/(g[8] + p[8]) + g[8] = (g[8] + p[8]) + return g + + +def standard_nms(S, thres): + order = np.argsort(S[:, 8])[::-1] + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + ovr = np.array([intersection(S[i], S[t]) for t in order[1:]]) + + inds = np.where(ovr <= thres)[0] + order = order[inds+1] + + return S[keep] + + +def nms_locality(polys, thres=0.3): + ''' + locality aware nms of EAST + :param polys: a N*9 numpy array. 
first 8 coordinates, then prob + :return: boxes after nms + ''' + S = [] + p = None + for g in polys: + if p is not None and intersection(g, p) > thres: + p = weighted_merge(g, p) + else: + if p is not None: + S.append(p) + p = g + if p is not None: + S.append(p) + + if len(S) == 0: + return np.array([]) + return standard_nms(np.array(S), thres) + + +if __name__ == '__main__': + # 343,350,448,135,474,143,369,359 + print(Polygon(np.array([[343, 350], [448, 135], + [474, 143], [369, 359]])).area) diff --git a/IndicPhotoOCR/detection/east_model.py b/IndicPhotoOCR/detection/east_model.py new file mode 100644 index 0000000000000000000000000000000000000000..912bb2df9bcafc068811cdbf3b81864754f5a990 --- /dev/null +++ b/IndicPhotoOCR/detection/east_model.py @@ -0,0 +1,242 @@ + +import torch.nn as nn +import math +import torch + + +from IndicPhotoOCR.detection import east_config as cfg +from IndicPhotoOCR.detection import east_utils as utils + + +def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = round(inp * expand_ratio) + self.use_res_connect = self.stride == 1 and inp == oup + + if expand_ratio == 1: + self.conv = nn.Sequential( + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + else: + self.conv = nn.Sequential( + # pw + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, width_mult=1.): + super(MobileNetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + interverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + # [6, 320, 1, 1], + ] + + # building first layer + # assert input_size % 32 == 0 + input_channel = int(input_channel * width_mult) + self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel + self.features = [conv_bn(3, input_channel, 2)] + # building inverted residual blocks + for t, c, n, s in interverted_residual_setting: + output_channel = int(c * width_mult) + for i in range(n): + if i == 0: + self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) + else: + self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) + input_channel = output_channel + + # make it nn.Sequential + self.features = nn.Sequential(*self.features) + + self._initialize_weights() + + def forward(self, x): + x = self.features(x) + # x = x.mean(3).mean(2) + # x = self.classifier(x) + return x + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, 
math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + +def mobilenet(pretrained=True, **kwargs): + """ + Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = MobileNetV2() + if pretrained: + model_dict = model.state_dict() + pretrained_dict = torch.load(cfg.pretrained_basemodel_path,map_location=torch.device('cpu'), weights_only=True) + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + # state_dict = torch.load(cfg.pretrained_basemodel_path) # add map_location='cpu' if no gpu + # model.load_state_dict(state_dict) + + return model + + +class East(nn.Module): + def __init__(self): + super(East, self).__init__() + self.mobilenet = mobilenet(True) + # self.si for stage i + self.s1 = nn.Sequential(*list(self.mobilenet.children())[0][0:4]) + self.s2 = nn.Sequential(*list(self.mobilenet.children())[0][4:7]) + self.s3 = nn.Sequential(*list(self.mobilenet.children())[0][7:14]) + self.s4 = nn.Sequential(*list(self.mobilenet.children())[0][14:17]) + + self.conv1 = nn.Conv2d(160+96, 128, 1) + self.bn1 = nn.BatchNorm2d(128) + self.relu1 = nn.ReLU() + + self.conv2 = nn.Conv2d(128, 128, 3, padding=1) + self.bn2 = nn.BatchNorm2d(128) + self.relu2 = nn.ReLU() + + self.conv3 = nn.Conv2d(128+32, 64, 1) + self.bn3 = nn.BatchNorm2d(64) + self.relu3 = nn.ReLU() + + self.conv4 = nn.Conv2d(64, 64, 3, padding=1) + self.bn4 = nn.BatchNorm2d(64) + self.relu4 = nn.ReLU() + + self.conv5 = nn.Conv2d(64+24, 64, 1) + self.bn5 = nn.BatchNorm2d(64) + self.relu5 = nn.ReLU() + + self.conv6 = nn.Conv2d(64, 32, 3, padding=1) + self.bn6 = nn.BatchNorm2d(32) + self.relu6 = nn.ReLU() + + self.conv7 = nn.Conv2d(32, 32, 3, padding=1) + self.bn7 = nn.BatchNorm2d(32) + self.relu7 = nn.ReLU() + + self.conv8 = nn.Conv2d(32, 1, 1) + self.sigmoid1 = nn.Sigmoid() + self.conv9 = nn.Conv2d(32, 4, 1) + self.sigmoid2 = nn.Sigmoid() + self.conv10 = nn.Conv2d(32, 1, 1) + self.sigmoid3 = nn.Sigmoid() + self.unpool1 = nn.Upsample(scale_factor=2, mode='bilinear') + self.unpool2 = nn.Upsample(scale_factor=2, mode='bilinear') + self.unpool3 = nn.Upsample(scale_factor=2, mode='bilinear') + + # utils.init_weights([self.conv1,self.conv2,self.conv3,self.conv4, + # self.conv5,self.conv6,self.conv7,self.conv8, + # self.conv9,self.conv10,self.bn1,self.bn2, + # self.bn3,self.bn4,self.bn5,self.bn6,self.bn7]) + + def forward(self, images): + images = utils.mean_image_subtraction(images) + + f0 = self.s1(images) + f1 = self.s2(f0) + f2 = self.s3(f1) + f3 = self.s4(f2) + + # _, f = self.mobilenet(images) + h = f3 # bs 2048 w/32 h/32 + g = (self.unpool1(h)) # bs 2048 w/16 h/16 + c = self.conv1(torch.cat((g, f2), 1)) + c = self.bn1(c) + c = self.relu1(c) + + h = self.conv2(c) # bs 128 w/16 h/16 + h = self.bn2(h) + h = self.relu2(h) + g = self.unpool2(h) # bs 128 w/8 h/8 + c = self.conv3(torch.cat((g, f1), 1)) + c = self.bn3(c) + c = self.relu3(c) + + h = self.conv4(c) # bs 64 w/8 h/8 + h = self.bn4(h) + h = self.relu4(h) + g = self.unpool3(h) # bs 64 w/4 h/4 + c = self.conv5(torch.cat((g, f0), 1)) + c = self.bn5(c) + c = self.relu5(c) + + h = self.conv6(c) # bs 32 w/4 h/4 + h = self.bn6(h) + h = self.relu6(h) + g = self.conv7(h) # bs 32 w/4 h/4 + g = self.bn7(g) + g = self.relu7(g) 
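+        # Output heads (descriptive note): conv8 produces the per-pixel text/no-text
+        # score map, conv9 the four box-edge distances (sigmoid scaled to [0, 512],
+        # matching cfg.input_size), and conv10 the rotation angle (sigmoid shifted and
+        # scaled to the range (-pi/4, pi/4)). The distance and angle maps are then
+        # concatenated into the 5-channel RBOX geometry output below.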
+ + F_score = self.conv8(g) # bs 1 w/4 h/4 + F_score = self.sigmoid1(F_score) + geo_map = self.conv9(g) + geo_map = self.sigmoid2(geo_map) * 512 + angle_map = self.conv10(g) + angle_map = self.sigmoid3(angle_map) + angle_map = (angle_map - 0.5) * math.pi / 2 + + F_geometry = torch.cat((geo_map, angle_map), 1) # bs 5 w/4 h/4 + + return F_score, F_geometry + + +model=East() diff --git a/IndicPhotoOCR/detection/east_preprossing.py b/IndicPhotoOCR/detection/east_preprossing.py new file mode 100644 index 0000000000000000000000000000000000000000..aa7eeffbfa163cf133736dae9b85967f2432f78e --- /dev/null +++ b/IndicPhotoOCR/detection/east_preprossing.py @@ -0,0 +1,681 @@ + +# coding:utf-8 +import glob +import csv +import cv2 +import os +import numpy as np +from shapely.geometry import Polygon + + +from IndicPhotoOCR.detection import east_config as cfg +from IndicPhotoOCR.detection import east_utils + + +def get_images(img_root): + files = [] + for ext in ['jpg']: + files.extend(glob.glob( + os.path.join(img_root, '*.{}'.format(ext)))) + # print(glob.glob( + # os.path.join(FLAGS.training_data_path, '*.{}'.format(ext)))) + return files + + +def load_annoataion(p): + ''' + load annotation from the text file + :param p: + :return: + ''' + text_polys = [] + text_tags = [] + if not os.path.exists(p): + return np.array(text_polys, dtype=np.float32) + with open(p, 'r', encoding='UTF-8') as f: + reader = csv.reader(f) + for line in reader: + label = line[-1] + # strip BOM. \ufeff for python3, \xef\xbb\bf for python2 + line = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in line] + + x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8])) + text_polys.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) + # print(text_polys) + if label == '*' or label == '###': + text_tags.append(True) + else: + text_tags.append(False) + return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=np.bool) + + +def polygon_area(poly): + ''' + compute area of a polygon + :param poly: + :return: + ''' + edge = [ + (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]), + (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]), + (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]), + (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]) + ] + return np.sum(edge) / 2. 
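+
+# Orientation note (illustrative, with assumed example values): polygon_area above is a
+# shoelace-style formula that returns a *signed* area, so its sign encodes the winding
+# direction of the quadrilateral. For example,
+#   square = np.array([[0, 0], [10, 0], [10, 10], [0, 10]], dtype=np.float32)
+#   polygon_area(square)        # -> -100.0 (expected winding)
+#   polygon_area(square[::-1])  # -> +100.0 (reversed winding)
+# check_and_validate_polys below relies on exactly this: a positive area is reported as
+# 'poly in wrong direction' and the points are reordered via poly[(0, 3, 2, 1), :] to
+# flip the orientation before the box is used.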
+ + +def check_and_validate_polys(polys, tags, xxx_todo_changeme): + ''' + check so that the text poly is in the same direction, + and also filter some invalid polygons + :param polys: + :param tags: + :return: + ''' + (h, w) = xxx_todo_changeme + if polys.shape[0] == 0: + return polys + polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1) + polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1) + + validated_polys = [] + validated_tags = [] + + # 判断四边形的点时针方向,以及是否是有效四边形 + for poly, tag in zip(polys, tags): + p_area = polygon_area(poly) + if abs(p_area) < 1: + # print poly + print('invalid poly') + continue + if p_area > 0: + print('poly in wrong direction') + poly = poly[(0, 3, 2, 1), :] + validated_polys.append(poly) + validated_tags.append(tag) + return np.array(validated_polys), np.array(validated_tags) + + +def crop_area(im, polys, tags, crop_background=False, max_tries=100): + ''' + make random crop from the input image + :param im: + :param polys: + :param tags: + :param crop_background: + :param max_tries: + :return: + ''' + h, w, _ = im.shape + pad_h = h // 10 + pad_w = w // 10 + h_array = np.zeros((h + pad_h * 2), dtype=np.int32) + w_array = np.zeros((w + pad_w * 2), dtype=np.int32) + for poly in polys: + poly = np.round(poly, decimals=0).astype(np.int32) + minx = np.min(poly[:, 0]) + maxx = np.max(poly[:, 0]) + w_array[minx + pad_w:maxx + pad_w] = 1 + miny = np.min(poly[:, 1]) + maxy = np.max(poly[:, 1]) + h_array[miny + pad_h:maxy + pad_h] = 1 + # ensure the cropped area not across a text,保证裁剪区域不能与文本交叉 + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + if len(h_axis) == 0 or len(w_axis) == 0: + return im, polys, tags + for i in range(max_tries): # 试验50次 + xx = np.random.choice(w_axis, size=2) + xmin = np.min(xx) - pad_w + xmax = np.max(xx) - pad_w + xmin = np.clip(xmin, 0, w - 1) + xmax = np.clip(xmax, 0, w - 1) + yy = np.random.choice(h_axis, size=2) + ymin = np.min(yy) - pad_h + ymax = np.max(yy) - pad_h + ymin = np.clip(ymin, 0, h - 1) + ymax = np.clip(ymax, 0, h - 1) + if xmax - xmin < cfg.min_crop_side_ratio * w or ymax - ymin < cfg.min_crop_side_ratio * h: + # area too small + continue + if polys.shape[0] != 0: + poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \ + & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax) + selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0] + else: + selected_polys = [] + if len(selected_polys) == 0: + # no text in this area + if crop_background: + return im[ymin:ymax + 1, xmin:xmax + 1, :], polys[selected_polys], tags[selected_polys] + else: + continue + im = im[ymin:ymax + 1, xmin:xmax + 1, :] + polys = polys[selected_polys] + tags = tags[selected_polys] + polys[:, :, 0] -= xmin + polys[:, :, 1] -= ymin + return im, polys, tags + + return im, polys, tags + + +def shrink_poly(poly, r): + ''' + fit a poly inside the origin poly, maybe bugs here... 
+ used for generate the score map + :param poly: the text poly + :param r: r in the paper + :return: the shrinked poly + ''' + # shrink ratio + R = 0.3 + # find the longer pair + if np.linalg.norm(poly[0] - poly[1]) + np.linalg.norm(poly[2] - poly[3]) > \ + np.linalg.norm(poly[0] - poly[3]) + np.linalg.norm(poly[1] - poly[2]): + # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2) + ## p0, p1 + theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0])) + poly[0][0] += R * r[0] * np.cos(theta) + poly[0][1] += R * r[0] * np.sin(theta) + poly[1][0] -= R * r[1] * np.cos(theta) + poly[1][1] -= R * r[1] * np.sin(theta) + ## p2, p3 + theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0])) + poly[3][0] += R * r[3] * np.cos(theta) + poly[3][1] += R * r[3] * np.sin(theta) + poly[2][0] -= R * r[2] * np.cos(theta) + poly[2][1] -= R * r[2] * np.sin(theta) + ## p0, p3 + theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1])) + poly[0][0] += R * r[0] * np.sin(theta) + poly[0][1] += R * r[0] * np.cos(theta) + poly[3][0] -= R * r[3] * np.sin(theta) + poly[3][1] -= R * r[3] * np.cos(theta) + ## p1, p2 + theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1])) + poly[1][0] += R * r[1] * np.sin(theta) + poly[1][1] += R * r[1] * np.cos(theta) + poly[2][0] -= R * r[2] * np.sin(theta) + poly[2][1] -= R * r[2] * np.cos(theta) + else: + ## p0, p3 + # print poly + theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1])) + poly[0][0] += R * r[0] * np.sin(theta) + poly[0][1] += R * r[0] * np.cos(theta) + poly[3][0] -= R * r[3] * np.sin(theta) + poly[3][1] -= R * r[3] * np.cos(theta) + ## p1, p2 + theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1])) + poly[1][0] += R * r[1] * np.sin(theta) + poly[1][1] += R * r[1] * np.cos(theta) + poly[2][0] -= R * r[2] * np.sin(theta) + poly[2][1] -= R * r[2] * np.cos(theta) + ## p0, p1 + theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0])) + poly[0][0] += R * r[0] * np.cos(theta) + poly[0][1] += R * r[0] * np.sin(theta) + poly[1][0] -= R * r[1] * np.cos(theta) + poly[1][1] -= R * r[1] * np.sin(theta) + ## p2, p3 + theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0])) + poly[3][0] += R * r[3] * np.cos(theta) + poly[3][1] += R * r[3] * np.sin(theta) + poly[2][0] -= R * r[2] * np.cos(theta) + poly[2][1] -= R * r[2] * np.sin(theta) + return poly + + +# def point_dist_to_line(p1, p2, p3): +# # compute the distance from p3 to p1-p2 +# return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1) + + +# 点p3到直线p12的距离 +def point_dist_to_line(p1, p2, p3): + # compute the distance from p3 to p1-p2 + # return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1) + a = np.linalg.norm(p1 - p2) + b = np.linalg.norm(p2 - p3) + c = np.linalg.norm(p3 - p1) + s = (a + b + c) / 2.0 + area = np.abs((s * (s - a) * (s - b) * (s - c))) ** 0.5 + if a < 1.0: + return (b + c) / 2.0 + return 2 * area / a + + +def fit_line(p1, p2): + # fit a line ax+by+c = 0 + if p1[0] == p1[1]: + return [1., 0., -p1[0]] + else: + [k, b] = np.polyfit(p1, p2, deg=1) + return [k, -1., b] + + +def line_cross_point(line1, line2): + # line1 0= ax+by+c, compute the cross point of line1 and line2 + if line1[0] != 0 and line1[0] == line2[0]: + print('Cross point does not exist') + return None + if line1[0] == 0 and line2[0] == 0: + print('Cross point does not exist') + return None + if line1[1] == 0: + x = -line1[2] + y = line2[0] * x + line2[2] + elif line2[1] == 
0: + x = -line2[2] + y = line1[0] * x + line1[2] + else: + k1, _, b1 = line1 + k2, _, b2 = line2 + x = -(b1 - b2) / (k1 - k2) + y = k1 * x + b1 + return np.array([x, y], dtype=np.float32) + + +def line_verticle(line, point): + # get the verticle line from line across point + if line[1] == 0: + verticle = [0, -1, point[1]] + else: + if line[0] == 0: + verticle = [1, 0, -point[0]] + else: + verticle = [-1. / line[0], -1, point[1] - (-1 / line[0] * point[0])] + return verticle + + +def rectangle_from_parallelogram(poly): + ''' + fit a rectangle from a parallelogram + :param poly: + :return: + ''' + p0, p1, p2, p3 = poly + angle_p0 = np.arccos(np.dot(p1 - p0, p3 - p0) / (np.linalg.norm(p0 - p1) * np.linalg.norm(p3 - p0))) + if angle_p0 < 0.5 * np.pi: + if np.linalg.norm(p0 - p1) > np.linalg.norm(p0 - p3): + # p0 and p2 + ## p0 + p2p3 = fit_line([p2[0], p3[0]], [p2[1], p3[1]]) + p2p3_verticle = line_verticle(p2p3, p0) + + new_p3 = line_cross_point(p2p3, p2p3_verticle) + ## p2 + p0p1 = fit_line([p0[0], p1[0]], [p0[1], p1[1]]) + p0p1_verticle = line_verticle(p0p1, p2) + + new_p1 = line_cross_point(p0p1, p0p1_verticle) + return np.array([p0, new_p1, p2, new_p3], dtype=np.float32) + else: + p1p2 = fit_line([p1[0], p2[0]], [p1[1], p2[1]]) + p1p2_verticle = line_verticle(p1p2, p0) + + new_p1 = line_cross_point(p1p2, p1p2_verticle) + p0p3 = fit_line([p0[0], p3[0]], [p0[1], p3[1]]) + p0p3_verticle = line_verticle(p0p3, p2) + + new_p3 = line_cross_point(p0p3, p0p3_verticle) + return np.array([p0, new_p1, p2, new_p3], dtype=np.float32) + else: + if np.linalg.norm(p0 - p1) > np.linalg.norm(p0 - p3): + # p1 and p3 + ## p1 + p2p3 = fit_line([p2[0], p3[0]], [p2[1], p3[1]]) + p2p3_verticle = line_verticle(p2p3, p1) + + new_p2 = line_cross_point(p2p3, p2p3_verticle) + ## p3 + p0p1 = fit_line([p0[0], p1[0]], [p0[1], p1[1]]) + p0p1_verticle = line_verticle(p0p1, p3) + + new_p0 = line_cross_point(p0p1, p0p1_verticle) + return np.array([new_p0, p1, new_p2, p3], dtype=np.float32) + else: + p0p3 = fit_line([p0[0], p3[0]], [p0[1], p3[1]]) + p0p3_verticle = line_verticle(p0p3, p1) + + new_p0 = line_cross_point(p0p3, p0p3_verticle) + p1p2 = fit_line([p1[0], p2[0]], [p1[1], p2[1]]) + p1p2_verticle = line_verticle(p1p2, p3) + + new_p2 = line_cross_point(p1p2, p1p2_verticle) + return np.array([new_p0, p1, new_p2, p3], dtype=np.float32) + + +def sort_rectangle(poly): + # sort the four coordinates of the polygon, points in poly should be sorted clockwise + # First find the lowest point + p_lowest = np.argmax(poly[:, 1]) + if np.count_nonzero(poly[:, 1] == poly[p_lowest, 1]) == 2: + # 底边平行于X轴, 那么p0为左上角 - if the bottom line is parallel to x-axis, then p0 must be the upper-left corner + p0_index = np.argmin(np.sum(poly, axis=1)) + p1_index = (p0_index + 1) % 4 + p2_index = (p0_index + 2) % 4 + p3_index = (p0_index + 3) % 4 + return poly[[p0_index, p1_index, p2_index, p3_index]], 0. 
+ else: + # 找到最低点右边的点 - find the point that sits right to the lowest point + p_lowest_right = (p_lowest - 1) % 4 + p_lowest_left = (p_lowest + 1) % 4 + angle = np.arctan( + -(poly[p_lowest][1] - poly[p_lowest_right][1]) / (poly[p_lowest][0] - poly[p_lowest_right][0])) + # assert angle > 0 + if angle <= 0: + print(angle, poly[p_lowest], poly[p_lowest_right]) + if angle / np.pi * 180 > 45: + # 这个点为p2 - this point is p2 + p2_index = p_lowest + p1_index = (p2_index - 1) % 4 + p0_index = (p2_index - 2) % 4 + p3_index = (p2_index + 1) % 4 + return poly[[p0_index, p1_index, p2_index, p3_index]], -(np.pi / 2 - angle) + else: + # 这个点为p3 - this point is p3 + p3_index = p_lowest + p0_index = (p3_index + 1) % 4 + p1_index = (p3_index + 2) % 4 + p2_index = (p3_index + 3) % 4 + return poly[[p0_index, p1_index, p2_index, p3_index]], angle + + +def restore_rectangle_rbox(origin, geometry): + d = geometry[:, :4] + angle = geometry[:, 4] + # for angle > 0 + origin_0 = origin[angle >= 0] + d_0 = d[angle >= 0] + angle_0 = angle[angle >= 0] + if origin_0.shape[0] > 0: + p = np.array([np.zeros(d_0.shape[0]), -d_0[:, 0] - d_0[:, 2], + d_0[:, 1] + d_0[:, 3], -d_0[:, 0] - d_0[:, 2], + d_0[:, 1] + d_0[:, 3], np.zeros(d_0.shape[0]), + np.zeros(d_0.shape[0]), np.zeros(d_0.shape[0]), + d_0[:, 3], -d_0[:, 2]]) + p = p.transpose((1, 0)).reshape((-1, 5, 2)) # N*5*2 + + rotate_matrix_x = np.array([np.cos(angle_0), np.sin(angle_0)]).transpose((1, 0)) + rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) # N*5*2 + + rotate_matrix_y = np.array([-np.sin(angle_0), np.cos(angle_0)]).transpose((1, 0)) + rotate_matrix_y = np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) + + p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis] # N*5*1 + p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis] # N*5*1 + + p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2) # N*5*2 + + p3_in_origin = origin_0 - p_rotate[:, 4, :] + new_p0 = p_rotate[:, 0, :] + p3_in_origin # N*2 + new_p1 = p_rotate[:, 1, :] + p3_in_origin + new_p2 = p_rotate[:, 2, :] + p3_in_origin + new_p3 = p_rotate[:, 3, :] + p3_in_origin + + new_p_0 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :], + new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1) # N*4*2 + else: + new_p_0 = np.zeros((0, 4, 2)) + # for angle < 0 + origin_1 = origin[angle < 0] + d_1 = d[angle < 0] + angle_1 = angle[angle < 0] + if origin_1.shape[0] > 0: + p = np.array([-d_1[:, 1] - d_1[:, 3], -d_1[:, 0] - d_1[:, 2], + np.zeros(d_1.shape[0]), -d_1[:, 0] - d_1[:, 2], + np.zeros(d_1.shape[0]), np.zeros(d_1.shape[0]), + -d_1[:, 1] - d_1[:, 3], np.zeros(d_1.shape[0]), + -d_1[:, 1], -d_1[:, 2]]) + p = p.transpose((1, 0)).reshape((-1, 5, 2)) # N*5*2 + + rotate_matrix_x = np.array([np.cos(-angle_1), -np.sin(-angle_1)]).transpose((1, 0)) + rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) # N*5*2 + + rotate_matrix_y = np.array([np.sin(-angle_1), np.cos(-angle_1)]).transpose((1, 0)) + rotate_matrix_y = np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) + + p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis] # N*5*1 + p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis] # N*5*1 + + p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2) # N*5*2 + + p3_in_origin = origin_1 - p_rotate[:, 4, :] + new_p0 = p_rotate[:, 0, :] + p3_in_origin # N*2 + new_p1 = p_rotate[:, 1, :] + p3_in_origin + 
new_p2 = p_rotate[:, 2, :] + p3_in_origin + new_p3 = p_rotate[:, 3, :] + p3_in_origin + + new_p_1 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :], + new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1) # N*4*2 + else: + new_p_1 = np.zeros((0, 4, 2)) + return np.concatenate([new_p_0, new_p_1]) + + +def restore_rectangle(origin, geometry): + return restore_rectangle_rbox(origin, geometry) + + +def generate_rbox(im_size, polys, tags): + h, w = im_size + poly_mask = np.zeros((h, w), dtype=np.uint8) + score_map = np.zeros((h, w), dtype=np.uint8) + geo_map = np.zeros((h, w, 5), dtype=np.float32) + # mask used during traning, to ignore some hard areas,用于忽略那些过小的文本 + training_mask = np.ones((h, w), dtype=np.uint8) + for poly_idx, poly_tag in enumerate(zip(polys, tags)): + poly = poly_tag[0] + tag = poly_tag[1] + + # 对每个顶点,找到经过他的两条边中较短的那条 + r = [None, None, None, None] + for i in range(4): + r[i] = min(np.linalg.norm(poly[i] - poly[(i + 1) % 4]), + np.linalg.norm(poly[i] - poly[(i - 1) % 4])) + # score map + # 放缩边框为之前的0.3倍,并对边框对应score图中的位置进行填充 + shrinked_poly = shrink_poly(poly.copy(), r).astype(np.int32)[np.newaxis, :, :] + cv2.fillPoly(score_map, shrinked_poly, 1) + cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1) + # if the poly is too small, then ignore it during training + # 如果文本框标签太小或者txt中没具体标记是什么内容,即*或者###,则加掩模,训练时忽略该部分 + poly_h = min(np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2])) + poly_w = min(np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3])) + if min(poly_h, poly_w) < cfg.min_text_size: + cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0) + if tag: + cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0) + + # 当前新加入的文本框区域像素点 + xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1)) + # if geometry == 'RBOX': + # 对任意两个顶点的组合生成一个平行四边形 - generate a parallelogram for any combination of two vertices + fitted_parallelograms = [] + for i in range(4): + # 选中p0和p1的连线边,生成两个平行四边形 + p0 = poly[i] + p1 = poly[(i + 1) % 4] + p2 = poly[(i + 2) % 4] + p3 = poly[(i + 3) % 4] + # 拟合ax+by+c=0 + edge = fit_line([p0[0], p1[0]], [p0[1], p1[1]]) + backward_edge = fit_line([p0[0], p3[0]], [p0[1], p3[1]]) + forward_edge = fit_line([p1[0], p2[0]], [p1[1], p2[1]]) + # 通过另外两个点距离edge的距离,来决定edge对应的平行线应该过p2还是p3 + if point_dist_to_line(p0, p1, p2) > point_dist_to_line(p0, p1, p3): + # 平行线经过p2 - parallel lines through p2 + if edge[1] == 0: + edge_opposite = [1, 0, -p2[0]] + else: + edge_opposite = [edge[0], -1, p2[1] - edge[0] * p2[0]] + else: + # 经过p3 - after p3 + if edge[1] == 0: + edge_opposite = [1, 0, -p3[0]] + else: + edge_opposite = [edge[0], -1, p3[1] - edge[0] * p3[0]] + # move forward edge + new_p0 = p0 + new_p1 = p1 + new_p2 = p2 + new_p3 = p3 + new_p2 = line_cross_point(forward_edge, edge_opposite) + if point_dist_to_line(p1, new_p2, p0) > point_dist_to_line(p1, new_p2, p3): + # across p0 + if forward_edge[1] == 0: + forward_opposite = [1, 0, -p0[0]] + else: + forward_opposite = [forward_edge[0], -1, p0[1] - forward_edge[0] * p0[0]] + else: + # across p3 + if forward_edge[1] == 0: + forward_opposite = [1, 0, -p3[0]] + else: + forward_opposite = [forward_edge[0], -1, p3[1] - forward_edge[0] * p3[0]] + new_p0 = line_cross_point(forward_opposite, edge) + new_p3 = line_cross_point(forward_opposite, edge_opposite) + fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0]) + # or move backward edge + new_p0 = p0 + new_p1 = p1 + new_p2 = p2 + new_p3 = p3 + new_p3 = 
line_cross_point(backward_edge, edge_opposite) + if point_dist_to_line(p0, p3, p1) > point_dist_to_line(p0, p3, p2): + # across p1 + if backward_edge[1] == 0: + backward_opposite = [1, 0, -p1[0]] + else: + backward_opposite = [backward_edge[0], -1, p1[1] - backward_edge[0] * p1[0]] + else: + # across p2 + if backward_edge[1] == 0: + backward_opposite = [1, 0, -p2[0]] + else: + backward_opposite = [backward_edge[0], -1, p2[1] - backward_edge[0] * p2[0]] + new_p1 = line_cross_point(backward_opposite, edge) + new_p2 = line_cross_point(backward_opposite, edge_opposite) + fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0]) + + # 选定面积最小的平行四边形 + areas = [Polygon(t).area for t in fitted_parallelograms] + parallelogram = np.array(fitted_parallelograms[np.argmin(areas)][:-1], dtype=np.float32) + # sort thie polygon + parallelogram_coord_sum = np.sum(parallelogram, axis=1) + min_coord_idx = np.argmin(parallelogram_coord_sum) + parallelogram = parallelogram[ + [min_coord_idx, (min_coord_idx + 1) % 4, (min_coord_idx + 2) % 4, (min_coord_idx + 3) % 4]] + + # 得到外包矩形即旋转角 + rectange = rectangle_from_parallelogram(parallelogram) + rectange, rotate_angle = sort_rectangle(rectange) + + p0_rect, p1_rect, p2_rect, p3_rect = rectange + # 对当前新加入的文本框区域像素点,根据其到矩形四边的距离修改geo_map + for y, x in xy_in_poly: + point = np.array([x, y], dtype=np.float32) + # top + geo_map[y, x, 0] = point_dist_to_line(p0_rect, p1_rect, point) + # right + geo_map[y, x, 1] = point_dist_to_line(p1_rect, p2_rect, point) + # down + geo_map[y, x, 2] = point_dist_to_line(p2_rect, p3_rect, point) + # left + geo_map[y, x, 3] = point_dist_to_line(p3_rect, p0_rect, point) + # angle + geo_map[y, x, 4] = rotate_angle + return score_map, geo_map, training_mask + + +def generator(index, + input_size=512, + background_ratio=3. 
/ 8, # 纯背景样本比例 + random_scale=np.array([0.5, 1, 2.0, 3.0]), # 提取多尺度图片信息 + image_list=None): + try: + im_fn = image_list[index] + im = cv2.imread(im_fn) + if im is None: + print("can't find image") + return None, None, None, None, None + h, w, _ = im.shape + # 所以要把gt去掉 + txt_fn = im_fn.replace(os.path.basename(im_fn).split('.')[1], 'txt') + if not os.path.exists(txt_fn): + print('text file {} does not exists'.format(txt_fn)) + return None, None, None, None, None + # 加载标注框信息 + text_polys, text_tags = load_annoataion(txt_fn) + + text_polys, text_tags = check_and_validate_polys(text_polys, text_tags, (h, w)) + + # random scale this image,随机选择一种尺度 + rd_scale = np.random.choice(random_scale) + im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) + text_polys *= rd_scale + + # random crop a area from image,3/8的选中的概率,裁剪纯背景的图片 + if np.random.rand() < background_ratio: + # crop background + im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=True) + if text_polys.shape[0] > 0: + # print("cannot find background") + return None, None, None, None, None + # pad and resize image + new_h, new_w, _ = im.shape + max_h_w_i = np.max([new_h, new_w, input_size]) + im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8) + im_padded[:new_h, :new_w, :] = im.copy() + # 将裁剪后图片扩充成512*512的图片 + im = cv2.resize(im_padded, dsize=(input_size, input_size)) + score_map = np.zeros((input_size, input_size), dtype=np.uint8) + geo_map_channels = 5 if cfg.geometry == 'RBOX' else 8 + geo_map = np.zeros((input_size, input_size, geo_map_channels), dtype=np.float32) + training_mask = np.ones((input_size, input_size), dtype=np.uint8) + else: + # 5 / 8的选中的概率,裁剪含文本信息的图片 + im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=False) + if text_polys.shape[0] == 0: + # print("cannot find txt ground") + return None, None, None, None, None + h, w, _ = im.shape + # pad the image to the training input size or the longer side of image + new_h, new_w, _ = im.shape + max_h_w_i = np.max([new_h, new_w, input_size]) + im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8) + im_padded[:new_h, :new_w, :] = im.copy() + im = im_padded + # resize the image to input size + # 填充,resize图像至设定尺寸 + new_h, new_w, _ = im.shape + resize_h = input_size + resize_w = input_size + im = cv2.resize(im, dsize=(resize_w, resize_h)) + # 将文本框坐标标签等比例修改 + resize_ratio_3_x = resize_w / float(new_w) + resize_ratio_3_y = resize_h / float(new_h) + text_polys[:, :, 0] *= resize_ratio_3_x + text_polys[:, :, 1] *= resize_ratio_3_y + new_h, new_w, _ = im.shape + score_map, geo_map, training_mask = generate_rbox((new_h, new_w), text_polys, text_tags) + + # 将一个样本的样本内容和标签信息append + images = im[:,:,::-1].astype(np.float32) + # 文件名加入列表 + image_fns = im_fn + # 512*512取提取四分之一行列 + score_maps = score_map[::4, ::4, np.newaxis].astype(np.float32) + geo_maps = geo_map[::4, ::4, :].astype(np.float32) + training_masks = training_mask[::4, ::4, np.newaxis].astype(np.float32) + # 符合一个样本之后输出 + return images, image_fns, score_maps, geo_maps, training_masks + + except Exception as e: + import traceback + traceback.print_exc() + + # print("Exception is exist!") + return None, None, None, None, None diff --git a/IndicPhotoOCR/detection/east_utils.py b/IndicPhotoOCR/detection/east_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..03c2a3838052579aa766ec00fd6f94fafd5554e4 --- /dev/null +++ b/IndicPhotoOCR/detection/east_utils.py @@ -0,0 +1,283 @@ + +import torch +import os +from torch.nn import init 
+import cv2 +import numpy as np +import time +import requests + +from IndicPhotoOCR.detection import east_config as cfg +from IndicPhotoOCR.detection import east_preprossing as preprossing +from IndicPhotoOCR.detection import east_locality_aware_nms as locality_aware_nms + + + +# Example usage: +model_info = { + "east": { + "paths": [ cfg.checkpoint, cfg.pretrained_basemodel_path], + "urls" : ["https://github.com/anikde/STocr/releases/download/e0.1.0/epoch_990_checkpoint.pth.tar", "https://github.com/anikde/STocr/releases/download/e0.1.0/mobilenet_v2.pth.tar"] + }, +} + +class ModelManager: + def __init__(self): + # self.root_model_dir = "bharatOCR/detection/" + pass + + def download_model(self, url, path): + response = requests.get(url, stream=True) + if response.status_code == 200: + with open(path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: # Filter out keep-alive chunks + f.write(chunk) + print(f"Downloaded: {path}") + else: + print(f"Failed to download from {url}") + + def ensure_model(self, model_name): + model_paths = model_info[model_name]["paths"] # Changed to handle multiple paths + urls = model_info[model_name]["urls"] # Changed to handle multiple URLs + + + for model_path, url in zip(model_paths, urls): + # full_model_path = os.path.join(self.root_model_dir, model_path) + + # Ensure the model path directory exists + os.makedirs(os.path.dirname(os.path.join(*cfg.pretrained_basemodel_path.split("/"))), exist_ok=True) + + if not os.path.exists(model_path): + print(f"Model not found locally. Downloading {model_name} from {url}...") + self.download_model(url, model_path) + else: + print(f"Model already exists at {model_path}. No need to download.") + + + +# # Initialize ModelManager and ensure Hindi models are downloaded +model_manager = ModelManager() +model_manager.ensure_model("east") + + + +def init_weights(m_list, init_type=cfg.init_type, gain=0.02): + print("EAST <==> Prepare <==> Init Network'{}' <==> Begin".format(cfg.init_type)) + # this will apply to each layer + for m in m_list: + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') # good for relu + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, gain) + init.constant_(m.bias.data, 0.0) + + print("EAST <==> Prepare <==> Init Network'{}' <==> Done".format(cfg.init_type)) + + +def Loading_checkpoint(model, optimizer, scheduler, filename='checkpoint.pth.tar'): + """[summary] + [description] + Arguments: + state {[type]} -- [description] a dict describe some params + Keyword Arguments: + filename {str} -- [description] (default: {'checkpoint.pth.tar'}) + """ + weightpath = os.path.abspath(cfg.checkpoint) + print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(weightpath)) + checkpoint = torch.load(weightpath) + start_epoch = checkpoint['epoch'] + 1 + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + 
scheduler.load_state_dict(checkpoint['scheduler']) + print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(weightpath)) + + return start_epoch + + +def save_checkpoint(epoch, model, optimizer, scheduler, filename='checkpoint.pth.tar'): + """[summary] + [description] + Arguments: + state {[type]} -- [description] a dict describe some params + Keyword Arguments: + filename {str} -- [description] (default: {'checkpoint.pth.tar'}) + """ + print('EAST <==> Save weight - epoch {} <==> Begin'.format(epoch)) + state = { + 'epoch': epoch, + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scheduler': scheduler.state_dict() + } + weight_dir = cfg.save_model_path + if not os.path.exists(weight_dir): + os.mkdir(weight_dir) + filename = 'epoch_' + str(epoch) + '_checkpoint.pth.tar' + file_path = os.path.join(weight_dir, filename) + torch.save(state, file_path) + print('EAST <==> Save weight - epoch {} <==> Done'.format(epoch)) + + +class Regularization(torch.nn.Module): + def __init__(self, model, weight_decay, p=2): + super(Regularization, self).__init__() + if weight_decay < 0: + print("param weight_decay can not <0") + exit(0) + self.model = model + self.weight_decay = weight_decay + self.p = p + self.weight_list = self.get_weight(model) + # self.weight_info(self.weight_list) + + def to(self, device): + self.device = device + super().to(device) + return self + + def forward(self, model): + self.weight_list = self.get_weight(model) # 获得最新的权重 + reg_loss = self.regularization_loss(self.weight_list, self.weight_decay, p=self.p) + return reg_loss + + def get_weight(self, model): + weight_list = [] + for name, param in model.named_parameters(): + if 'weight' in name: + weight = (name, param) + weight_list.append(weight) + return weight_list + + def regularization_loss(self, weight_list, weight_decay, p=2): + reg_loss = 0 + for name, w in weight_list: + l2_reg = torch.norm(w, p=p) + reg_loss = reg_loss + l2_reg + + reg_loss = weight_decay * reg_loss + return reg_loss + + def weight_info(self, weight_list): + print("---------------regularization weight---------------") + for name, w in weight_list: + print(name) + print("---------------------------------------------------") + + +def resize_image(im, max_side_len=2400): + ''' + resize image to a size multiple of 32 which is required by the network + :param im: the resized image + :param max_side_len: limit of max image size to avoid out of memory in gpu + :return: the resized image and the resize ratio + ''' + h, w, _ = im.shape + + resize_w = w + resize_h = h + + # limit the max side + """ + if max(resize_h, resize_w) > max_side_len: + ratio = float(max_side_len) / resize_h if resize_h > resize_w else float(max_side_len) / resize_w + else: + ratio = 1. 
+ + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + """ + + resize_h = resize_h if resize_h % 32 == 0 else (resize_h // 32 - 1) * 32 + resize_w = resize_w if resize_w % 32 == 0 else (resize_w // 32 - 1) * 32 + #resize_h, resize_w = 512, 512 + im = cv2.resize(im, (int(resize_w), int(resize_h))) + + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return im, (ratio_h, ratio_w) + + +def detect(score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_thres=0.2): + ''' + restore text boxes from score map and geo map + :param score_map: + :param geo_map: + :param timer: + :param score_map_thresh: threshhold for score map + :param box_thresh: threshhold for boxes + :param nms_thres: threshold for nms + :return: + ''' + + # score_map 和 geo_map 的维数进行调整 + if len(score_map.shape) == 4: + score_map = score_map[0, :, :, 0] + geo_map = geo_map[0, :, :, :] + # filter the score map + xy_text = np.argwhere(score_map > score_map_thresh) + # sort the text boxes via the y axis + xy_text = xy_text[np.argsort(xy_text[:, 0])] + # restore + start = time.time() + text_box_restored = preprossing.restore_rectangle(xy_text[:, ::-1] * 4, + geo_map[xy_text[:, 0], xy_text[:, 1], :]) # N*4*2 + # print('{} text boxes before nms'.format(text_box_restored.shape[0])) + boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32) + boxes[:, :8] = text_box_restored.reshape((-1, 8)) + boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]] + timer['restore'] = time.time() - start + # nms part + start = time.time() + boxes = locality_aware_nms.nms_locality(boxes.astype(np.float64), nms_thres) + timer['nms'] = time.time() - start + # print(timer['nms']) + if boxes.shape[0] == 0: + return None, timer + + # here we filter some low score boxes by the average score map, this is different from the orginal paper + for i, box in enumerate(boxes): + mask = np.zeros_like(score_map, dtype=np.uint8) + cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1) + boxes[i, 8] = cv2.mean(score_map, mask)[0] + boxes = boxes[boxes[:, 8] > box_thresh] + return boxes, timer + + +def sort_poly(p): + min_axis = np.argmin(np.sum(p, axis=1)) + p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]] + if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]): + return p + else: + return p[[0, 3, 2, 1]] + + +def mean_image_subtraction(images, means=cfg.means): + ''' + image normalization + :param images: bs * w * h * channel + :param means: + :return: + ''' + num_channels = images.data.shape[1] + if len(means) != num_channels: + raise ValueError('len(means) must match the number of channels') + for i in range(num_channels): + images.data[:, i, :, :] -= means[i] + + return images diff --git a/IndicPhotoOCR/ocr.py b/IndicPhotoOCR/ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..7da65446ef69feb6d16c678b127dce727a648d77 --- /dev/null +++ b/IndicPhotoOCR/ocr.py @@ -0,0 +1,154 @@ +import sys +import os +import torch +from PIL import Image +import cv2 +import numpy as np + + +from IndicPhotoOCR.detection.east_detector import EASTdetector +from IndicPhotoOCR.script_identification.CLIP_identifier import CLIPidentifier +from IndicPhotoOCR.recognition.parseq_recogniser import PARseqrecogniser +import IndicPhotoOCR.detection.east_config as cfg + + +class OCR: + def __init__(self, device='cuda:0', verbose=False): + # self.detect_model_checkpoint = detect_model_checkpoint + self.device = device + self.verbose = verbose + # self.image_path = image_path + 
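+        # Pipeline components (descriptive note): EAST for text detection, a CLIP-based
+        # identifier for script classification, and PARseq for word recognition.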
self.detector = EASTdetector() + self.recogniser = PARseqrecogniser() + self.identifier = CLIPidentifier() + + def detect(self, image_path, detect_model_checkpoint=cfg.checkpoint): + """Run the detection model to get bounding boxes of text areas.""" + + if self.verbose: + print("Running text detection...") + detections = self.detector.detect(image_path, detect_model_checkpoint, self.device) + # print(detections) + return detections['detections'] + + def visualize_detection(self, image_path, detections, save_path=None, show=False): + # Default save path if none is provided + default_save_path = "test.png" + path_to_save = save_path if save_path is not None else default_save_path + + # Get the directory part of the path + directory = os.path.dirname(path_to_save) + + # Check if the directory exists, and create it if it doesn’t + if directory and not os.path.exists(directory): + os.makedirs(directory) + print(f"Created directory: {directory}") + + # Read the image and draw bounding boxes + image = cv2.imread(image_path) + for box in detections: + # Convert list of points to a numpy array with int type + points = np.array(box, np.int32) + points = points.reshape((-1, 1, 2)) # Reshape for cv2.polylines + # Draw the polygon + cv2.polylines(image, [points], isClosed=True, color=(0, 255, 0), thickness=3) + + # Show the image if 'show' is True + if show: + plt.figure(figsize=(10, 10)) + plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + plt.axis("off") + plt.show() + + # Save the annotated image + cv2.imwrite(path_to_save, image) + print(f"Image saved at: {path_to_save}") + + def crop_and_identify_script(self, image, bbox): + """ + Crop a text area from the image and identify its script language. + + Args: + image (PIL.Image): The full image. + bbox (list): List of four corner points, each a [x, y] pair. + + Returns: + str: Identified script language. 
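+            str: Path to the temporary cropped-image file (returned alongside the script).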
+ """ + # Extract x and y coordinates from the four corner points + x_coords = [point[0] for point in bbox] + y_coords = [point[1] for point in bbox] + + # Get the bounding box coordinates (min and max) + x_min, y_min = min(x_coords), min(y_coords) + x_max, y_max = max(x_coords), max(y_coords) + + # Crop the image based on the bounding box + cropped_image = image.crop((x_min, y_min, x_max, y_max)) + root_image_dir = "IndicPhotoOCR/script_identification" + os.makedirs(f"{root_image_dir}/images", exist_ok=True) + # Temporarily save the cropped image to pass to the script model + cropped_path = f'{root_image_dir}/images/temp_crop_{x_min}_{y_min}.jpg' + cropped_image.save(cropped_path) + + # Predict script language, here we assume "hindi" as the model name + if self.verbose: + print("Identifying script for the cropped area...") + script_lang = self.identifier.identify(cropped_path, "hindi") # Use "hindi" as the model name + # print(script_lang) + + # Clean up temporary file + # os.remove(cropped_path) + + return script_lang, cropped_path + + def recognise(self, cropped_image_path, script_lang): + """Recognize text in a cropped image area using the identified script.""" + if self.verbose: + print("Recognizing text in detected area...") + recognized_text = self.recogniser.recognise(script_lang, cropped_image_path, script_lang, self.verbose) + # print(recognized_text) + return recognized_text + + def ocr(self, image_path): + """Process the image by detecting text areas, identifying script, and recognizing text.""" + recognized_words = [] + image = Image.open(image_path) + + # Run detection + detections = self.detect(image_path) + + # Process each detected text area + for bbox in detections: + # Crop and identify script language + script_lang, cropped_path = self.crop_and_identify_script(image, bbox) + + # Check if the script language is valid + if script_lang: + + # Recognize text + recognized_word = self.recognise(cropped_path, script_lang) + recognized_words.append(recognized_word) + + if self.verbose: + print(f"Recognized word: {recognized_word}") + + return recognized_words + +if __name__ == '__main__': + # detect_model_checkpoint = 'bharatSTR/East/tmp/epoch_990_checkpoint.pth.tar' + sample_image_path = 'test_images/image_141.jpg' + cropped_image_path = 'test_images/cropped_image/image_141_0.jpg' + + ocr = OCR(device="cpu", verbose=False) + + # detections = ocr.detect(sample_image_path) + # print(detections) + + # ocr.visualize_detection(sample_image_path, detections) + + # recognition = ocr.recognise(cropped_image_path, "hindi") + # print(recognition) + + recognised_words = ocr.ocr(sample_image_path) + print(recognised_words) \ No newline at end of file diff --git a/IndicPhotoOCR/recognition/__init__.py b/IndicPhotoOCR/recognition/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/IndicPhotoOCR/recognition/parseq_recogniser.py b/IndicPhotoOCR/recognition/parseq_recogniser.py new file mode 100644 index 0000000000000000000000000000000000000000..102448006b8afb2e67304c926bd07341a84cccc9 --- /dev/null +++ b/IndicPhotoOCR/recognition/parseq_recogniser.py @@ -0,0 +1,215 @@ +import csv +# import fire +import json +import numpy as np +import os +# import pandas as pd +import sys +import torch +import requests + +from dataclasses import dataclass +from PIL import Image +from nltk import edit_distance +from torchvision import transforms as T +from typing import Optional, Callable, Sequence, Tuple +from tqdm import tqdm 
+ + +from IndicPhotoOCR.utils.strhub.data.module import SceneTextDataModule +from IndicPhotoOCR.utils.strhub.models.utils import load_from_checkpoint + + +model_info = { + "assamese": { + "path": "models/assamese.ckpt", + "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/assamese.ckpt", + }, + "bengali": { + "path": "models/bengali.ckpt", + "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/bengali.ckpt", + }, + "hindi": { + "path": "models/hindi.ckpt", + "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/hindi.ckpt", + }, + "gujarati": { + "path": "models/gujarati.ckpt", + "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/gujarati.ckpt", + }, + "marathi": { + "path": "models/marathi.ckpt", + "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/marathi.ckpt", + }, + "odia": { + "path": "models/odia.ckpt", + "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/odia.ckpt", + }, + "punjabi": { + "path": "models/punjabi.ckpt", + "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/punjabi.ckpt", + }, + "tamil": { + "path": "models/tamil.ckpt", + "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/tamil.ckpt", + }, + "telugu": { + "path": "models/telugu.ckpt", + "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/telugu.ckpt", + } +} + +class PARseqrecogniser: + def __init__(self): + pass + + def get_transform(self, img_size: Tuple[int], augment: bool = False, rotation: int = 0): + transforms = [] + if augment: + from .augment import rand_augment_transform + transforms.append(rand_augment_transform()) + if rotation: + transforms.append(lambda img: img.rotate(rotation, expand=True)) + transforms.extend([ + T.Resize(img_size, T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(0.5, 0.5) + ]) + return T.Compose(transforms) + + + def load_model(self, device, checkpoint): + model = load_from_checkpoint(checkpoint).eval().to(device) + return model + + def get_model_output(self, device, model, image_path): + hp = model.hparams + transform = self.get_transform(hp.img_size, rotation=0) + + image_name = image_path.split("/")[-1] + img = Image.open(image_path).convert('RGB') + img = transform(img) + logits = model(img.unsqueeze(0).to(device)) + probs = logits.softmax(-1) + preds, probs = model.tokenizer.decode(probs) + text = model.charset_adapter(preds[0]) + scores = probs[0].detach().cpu().numpy() + + return text + + # Ensure model file exists; download directly if not + def ensure_model(self, model_name): + model_path = model_info[model_name]["path"] + url = model_info[model_name]["url"] + root_model_dir = "IndicPhotoOCR/recognition/" + model_path = os.path.join(root_model_dir, model_path) + + if not os.path.exists(model_path): + print(f"Model not found locally. Downloading {model_name} from {url}...") + + # Start the download with a progress bar + response = requests.get(url, stream=True) + total_size = int(response.headers.get('content-length', 0)) + os.makedirs(f"{root_model_dir}/models", exist_ok=True) + + with open(model_path, "wb") as f, tqdm( + desc=model_name, + total=total_size, + unit='B', + unit_scale=True, + unit_divisor=1024, + ) as bar: + for data in response.iter_content(chunk_size=1024): + f.write(data) + bar.update(len(data)) + + print(f"Downloaded model for {model_name}.") + + return model_path + + def bstr(checkpoint, language, image_dir, save_dir): + """ + Runs the OCR model to process images and save the output as a JSON file. 
+ + Args: + checkpoint (str): Path to the model checkpoint file. + language (str): Language code (e.g., 'hindi', 'english'). + image_dir (str): Directory containing the images to process. + save_dir (str): Directory where the output JSON file will be saved. + + Example usage: + python your_script.py --checkpoint /path/to/checkpoint.ckpt --language hindi --image_dir /path/to/images --save_dir /path/to/save + """ + device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") + + if language != "english": + model = load_model(device, checkpoint) + else: + model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device) + + parseq_dict = {} + for image_path in tqdm(os.listdir(image_dir)): + assert os.path.exists(os.path.join(image_dir, image_path)) == True, f"{image_path}" + text = get_model_output(device, model, os.path.join(image_dir, image_path), language=f"{language}") + + filename = image_path.split('/')[-1] + parseq_dict[filename] = text + + os.makedirs(save_dir, exist_ok=True) + with open(f"{save_dir}/{language}_test.json", 'w') as json_file: + json.dump(parseq_dict, json_file, indent=4, ensure_ascii=False) + + + def bstr_onImage(checkpoint, language, image_path): + """ + Runs the OCR model to process images and save the output as a JSON file. + + Args: + checkpoint (str): Path to the model checkpoint file. + language (str): Language code (e.g., 'hindi', 'english'). + image_dir (str): Directory containing the images to process. + save_dir (str): Directory where the output JSON file will be saved. + + Example usage: + python your_script.py --checkpoint /path/to/checkpoint.ckpt --language hindi --image_dir /path/to/images --save_dir /path/to/save + """ + device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") + + if language != "english": + model = load_model(device, checkpoint) + else: + model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device) + + # parseq_dict = {} + # for image_path in tqdm(os.listdir(image_dir)): + # assert os.path.exists(os.path.join(image_dir, image_path)) == True, f"{image_path}" + text = get_model_output(device, model, image_path, language=f"{language}") + + return text + + + def recognise(self, checkpoint: str, image_path: str, language: str, verbose: bool) -> str: + """ + Loads the desired model and returns the recognized word from the specified image. + + Args: + checkpoint (str): Path to the model checkpoint file. + language (str): Language code (e.g., 'hindi', 'english'). + image_path (str): Path to the image for which text recognition is needed. + + Returns: + str: The recognized text from the image. 
+ """ + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + if language != "english": + model_path = self.ensure_model(checkpoint) + model = self.load_model(device, model_path) + else: + model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True, verbose=verbose).eval().to(device) + + recognized_text = self.get_model_output(device, model, image_path) + + return recognized_text +# if __name__ == '__main__': +# fire.Fire(main) \ No newline at end of file diff --git a/IndicPhotoOCR/script_identification/CLIP_identifier.py b/IndicPhotoOCR/script_identification/CLIP_identifier.py new file mode 100644 index 0000000000000000000000000000000000000000..c47531f88de065104461ef2f909ab80150e65e79 --- /dev/null +++ b/IndicPhotoOCR/script_identification/CLIP_identifier.py @@ -0,0 +1,201 @@ + +import torch +import clip +from PIL import Image +from io import BytesIO +import os +import requests + +# Model information dictionary containing model paths and language subcategories +model_info = { + "hindi": { + "path": "models/clip_finetuned_hindienglish_real.pth", + "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglish_real.pth", + "subcategories": ["hindi", "english"] + }, + "hinengasm": { + "path": "models/clip_finetuned_hindienglishassamese_real.pth", + "url": "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishassamese_real.pth", + "subcategories": ["hindi", "english", "assamese"] + }, + "hinengben": { + "path": "models/clip_finetuned_hindienglishbengali_real.pth", + "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishbengali_real.pth", + "subcategories": ["hindi", "english", "bengali"] + }, + "hinengguj": { + "path": "models/clip_finetuned_hindienglishgujarati_real.pth", + "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishgujarati_real.pth", + "subcategories": ["hindi", "english", "gujarati"] + }, + "hinengkan": { + "path": "models/clip_finetuned_hindienglishkannada_real.pth", + "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishkannada_real.pth", + "subcategories": ["hindi", "english", "kannada"] + }, + "hinengmal": { + "path": "models/clip_finetuned_hindienglishmalayalam_real.pth", + "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmalayalam_real.pth", + "subcategories": ["hindi", "english", "malayalam"] + }, + "hinengmar": { + "path": "models/clip_finetuned_hindienglishmarathi_real.pth", + "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmarathi_real.pth", + "subcategories": ["hindi", "english", "marathi"] + }, + "hinengmei": { + "path": "models/clip_finetuned_hindienglishmeitei_real.pth", + "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmeitei_real.pth", + "subcategories": ["hindi", "english", "meitei"] + }, + "hinengodi": { + "path": "models/clip_finetuned_hindienglishodia_real.pth", + "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishodia_real.pth", + "subcategories": ["hindi", "english", "odia"] + }, + "hinengpun": { + "path": "models/clip_finetuned_hindienglishpunjabi_real.pth", + "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishpunjabi_real.pth", + "subcategories": ["hindi", "english", "punjabi"] + }, + 
"hinengtam": { + "path": "models/clip_finetuned_hindienglishtamil_real.pth", + "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishtamil_real.pth", + "subcategories": ["hindi", "english", "tamil"] + }, + "hinengtel": { + "path": "models/clip_finetuned_hindienglishtelugu_real.pth", + "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishtelugu_real.pth", + "subcategories": ["hindi", "english", "telugu"] + }, + "hinengurd": { + "path": "models/clip_finetuned_hindienglishurdu_real.pth", + "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishurdu_real.pth", + "subcategories": ["hindi", "english", "urdu"] + }, + + +} + + +# Set device to CUDA if available, otherwise use CPU +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +clip_model, preprocess = clip.load("ViT-B/32", device=device) + +class CLIPFineTuner(torch.nn.Module): + """ + Fine-tuning class for the CLIP model to adapt to specific tasks. + + Attributes: + model (torch.nn.Module): The CLIP model to be fine-tuned. + classifier (torch.nn.Linear): A linear classifier to map features to the desired number of classes. + """ + def __init__(self, model, num_classes): + """ + Initializes the fine-tuner with the CLIP model and classifier. + + Args: + model (torch.nn.Module): The base CLIP model. + num_classes (int): The number of target classes for classification. + """ + super(CLIPFineTuner, self).__init__() + self.model = model + self.classifier = torch.nn.Linear(model.visual.output_dim, num_classes) + + def forward(self, x): + """ + Forward pass for image classification. + + Args: + x (torch.Tensor): Preprocessed input tensor for an image. + + Returns: + torch.Tensor: Logits for each class. + """ + with torch.no_grad(): + features = self.model.encode_image(x).float() # Extract image features from CLIP model + return self.classifier(features) # Return class logits + +class CLIPidentifier: + def __init__(self): + pass + + # Ensure model file exists; download directly if not + def ensure_model(self, model_name): + model_path = model_info[model_name]["path"] + url = model_info[model_name]["url"] + root_model_dir = "IndicPhotoOCR/script_identification/" + model_path = os.path.join(root_model_dir, model_path) + + if not os.path.exists(model_path): + print(f"Model not found locally. Downloading {model_name} from {url}...") + response = requests.get(url, stream=True) + os.makedirs(f"{root_model_dir}/models", exist_ok=True) + with open(f"{model_path}", "wb") as f: + f.write(response.content) + print(f"Downloaded model for {model_name}.") + + return model_path + + # Prediction function to verify and load the model + def identify(self, image_path, model_name): + """ + Predicts the class of an input image using a fine-tuned CLIP model. + + Args: + image_path (str): Path to the input image file. + model_name (str): Name of the model (e.g., hineng, hinengpun, hinengguj) as specified in `model_info`. + + Returns: + dict: Contains either `predicted_class` if successful or `error` if an exception occurs. 
+ + Example usage: + result = predict("sample_image.jpg", "hinengguj") + print(result) # Output might be {'predicted_class': 'hindi'} + """ + try: + # Validate model name and retrieve associated subcategories + if model_name not in model_info: + return {"error": "Invalid model name"} + + # Ensure the model file is downloaded and accessible + model_path = self.ensure_model(model_name) + + + subcategories = model_info[model_name]["subcategories"] + num_classes = len(subcategories) + + # Load the fine-tuned model with the specified number of classes + model_ft = CLIPFineTuner(clip_model, num_classes) + model_ft.load_state_dict(torch.load(model_path, map_location=device)) + model_ft = model_ft.to(device) + model_ft.eval() + + # Load and preprocess the image + image = Image.open(image_path).convert("RGB") + input_tensor = preprocess(image).unsqueeze(0).to(device) + + # Run the model and get the prediction + outputs = model_ft(input_tensor) + _, predicted_idx = torch.max(outputs, 1) + predicted_class = subcategories[predicted_idx.item()] + + return predicted_class + + except Exception as e: + return {"error": str(e)} + + +# if __name__ == "__main__": +# import argparse + +# # Argument parser for command line usage +# parser = argparse.ArgumentParser(description="Image classification using CLIP fine-tuned model") +# parser.add_argument("image_path", type=str, help="Path to the input image") +# parser.add_argument("model_name", type=str, choices=model_info.keys(), help="Name of the model (e.g., hineng, hinengpun, hinengguj)") + +# args = parser.parse_args() + +# # Execute prediction with command line inputs +# result = predict(args.image_path, args.model_name) +# print(result) \ No newline at end of file diff --git a/IndicPhotoOCR/script_identification/__init__.py b/IndicPhotoOCR/script_identification/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/IndicPhotoOCR/theme.py b/IndicPhotoOCR/theme.py new file mode 100644 index 0000000000000000000000000000000000000000..ca812bf7d38568c0d831dc1476bffa3486970e33 --- /dev/null +++ b/IndicPhotoOCR/theme.py @@ -0,0 +1,43 @@ +from __future__ import annotations +from typing import Iterable +import gradio as gr +from gradio.themes.base import Base +from gradio.themes.utils import colors, fonts, sizes +import time + + +class Seafoam(Base): + def __init__( + self, + *, + primary_hue: colors.Color | str = colors.emerald, + secondary_hue: colors.Color | str = colors.blue, + neutral_hue: colors.Color | str = colors.gray, + spacing_size: sizes.Size | str = sizes.spacing_md, + radius_size: sizes.Size | str = sizes.radius_md, + text_size: sizes.Size | str = sizes.text_lg, + font: fonts.Font + | str + | Iterable[fonts.Font | str] = ( + fonts.GoogleFont("Quicksand"), + "ui-sans-serif", + "sans-serif", + ), + font_mono: fonts.Font + | str + | Iterable[fonts.Font | str] = ( + fonts.GoogleFont("IBM Plex Mono"), + "ui-monospace", + "monospace", + ), + ): + super().__init__( + primary_hue=primary_hue, + secondary_hue=secondary_hue, + neutral_hue=neutral_hue, + spacing_size=spacing_size, + radius_size=radius_size, + text_size=text_size, + font=font, + font_mono=font_mono, + ) \ No newline at end of file diff --git a/IndicPhotoOCR/utils/strhub/__init__.py b/IndicPhotoOCR/utils/strhub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d740e413e1fb8d335027dd8ff6d3aa393d45e84 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/__init__.py @@ -0,0 +1,2 @@ +# from 
data.module import SceneTextDataModule +# from model.utils import load_from_checkpoint \ No newline at end of file diff --git a/IndicPhotoOCR/utils/strhub/data/__init__.py b/IndicPhotoOCR/utils/strhub/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36e3d6f0bae7320fb3ae022de41146597a5481b6 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/data/__init__.py @@ -0,0 +1 @@ +# from .module import SceneTextDataModule \ No newline at end of file diff --git a/IndicPhotoOCR/utils/strhub/data/aa_overrides.py b/IndicPhotoOCR/utils/strhub/data/aa_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..ef374e2e4166c3847d80d30bab2b0eb6ba88d70c --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/data/aa_overrides.py @@ -0,0 +1,46 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Extends default ops to accept optional parameters.""" +from functools import partial + +from timm.data.auto_augment import _LEVEL_DENOM, LEVEL_TO_ARG, NAME_TO_OP, _randomly_negate, rotate + + +def rotate_expand(img, degrees, **kwargs): + """Rotate operation with expand=True to avoid cutting off the characters""" + kwargs['expand'] = True + return rotate(img, degrees, **kwargs) + + +def _level_to_arg(level, hparams, key, default): + magnitude = hparams.get(key, default) + level = (level / _LEVEL_DENOM) * magnitude + level = _randomly_negate(level) + return (level,) + + +def apply(): + # Overrides + NAME_TO_OP.update({ + 'Rotate': rotate_expand, + }) + LEVEL_TO_ARG.update({ + 'Rotate': partial(_level_to_arg, key='rotate_deg', default=30.0), + 'ShearX': partial(_level_to_arg, key='shear_x_pct', default=0.3), + 'ShearY': partial(_level_to_arg, key='shear_y_pct', default=0.3), + 'TranslateXRel': partial(_level_to_arg, key='translate_x_pct', default=0.45), + 'TranslateYRel': partial(_level_to_arg, key='translate_y_pct', default=0.45), + }) diff --git a/IndicPhotoOCR/utils/strhub/data/augment.py b/IndicPhotoOCR/utils/strhub/data/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..ed8832503693863907640b83d5771de90ed6e773 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/data/augment.py @@ -0,0 +1,112 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
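The rotate override in aa_overrides.py exists because timm's stock Rotate keeps the original canvas size, which can clip characters at the edges of a word crop. A standalone PIL sketch of the difference, illustrative only (the file name and angle are placeholders):

# Illustrative: expand=True is what rotate_expand relies on to keep the full
# rotated content instead of cropping it to the original canvas.
from PIL import Image

img = Image.open("word_crop.jpg")     # hypothetical word crop
clipped = img.rotate(30)              # same size as img; corners are cut off
kept = img.rotate(30, expand=True)    # canvas grows to fit the rotated image
print(img.size, clipped.size, kept.size)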
+ +from functools import partial + +import imgaug.augmenters as iaa +import numpy as np +from PIL import Image, ImageFilter + +from timm.data import auto_augment + +from strhub.data import aa_overrides + +aa_overrides.apply() + +_OP_CACHE = {} + + +def _get_op(key, factory): + try: + op = _OP_CACHE[key] + except KeyError: + op = factory() + _OP_CACHE[key] = op + return op + + +def _get_param(level, img, max_dim_factor, min_level=1): + max_level = max(min_level, max_dim_factor * max(img.size)) + return round(min(level, max_level)) + + +def gaussian_blur(img, radius, **__): + radius = _get_param(radius, img, 0.02) + key = 'gaussian_blur_' + str(radius) + op = _get_op(key, lambda: ImageFilter.GaussianBlur(radius)) + return img.filter(op) + + +def motion_blur(img, k, **__): + k = _get_param(k, img, 0.08, 3) | 1 # bin to odd values + key = 'motion_blur_' + str(k) + op = _get_op(key, lambda: iaa.MotionBlur(k)) + return Image.fromarray(op(image=np.asarray(img))) + + +def gaussian_noise(img, scale, **_): + scale = _get_param(scale, img, 0.25) | 1 # bin to odd values + key = 'gaussian_noise_' + str(scale) + op = _get_op(key, lambda: iaa.AdditiveGaussianNoise(scale=scale)) + return Image.fromarray(op(image=np.asarray(img))) + + +def poisson_noise(img, lam, **_): + lam = _get_param(lam, img, 0.2) | 1 # bin to odd values + key = 'poisson_noise_' + str(lam) + op = _get_op(key, lambda: iaa.AdditivePoissonNoise(lam)) + return Image.fromarray(op(image=np.asarray(img))) + + +def _level_to_arg(level, _hparams, max): + level = max * level / auto_augment._LEVEL_DENOM + return (level,) + + +_RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy() +_RAND_TRANSFORMS.remove('SharpnessIncreasing') # remove, interferes with *blur ops +_RAND_TRANSFORMS.extend([ + 'GaussianBlur', + # 'MotionBlur', + # 'GaussianNoise', + 'PoissonNoise', +]) +auto_augment.LEVEL_TO_ARG.update({ + 'GaussianBlur': partial(_level_to_arg, max=4), + 'MotionBlur': partial(_level_to_arg, max=20), + 'GaussianNoise': partial(_level_to_arg, max=0.1 * 255), + 'PoissonNoise': partial(_level_to_arg, max=40), +}) +auto_augment.NAME_TO_OP.update({ + 'GaussianBlur': gaussian_blur, + 'MotionBlur': motion_blur, + 'GaussianNoise': gaussian_noise, + 'PoissonNoise': poisson_noise, +}) + + +def rand_augment_transform(magnitude=5, num_layers=3): + # These are tuned for magnitude=5, which means that effective magnitudes are half of these values. + hparams = { + 'rotate_deg': 30, + 'shear_x_pct': 0.9, + 'shear_y_pct': 0.2, + 'translate_x_pct': 0.10, + 'translate_y_pct': 0.30, + } + ra_ops = auto_augment.rand_augment_ops(magnitude, hparams=hparams, transforms=_RAND_TRANSFORMS) + # Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice) + choice_weights = [1.0 / len(ra_ops) for _ in range(len(ra_ops))] + return auto_augment.RandAugment(ra_ops, num_layers, choice_weights) diff --git a/IndicPhotoOCR/utils/strhub/data/dataset.py b/IndicPhotoOCR/utils/strhub/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..65954c23127f02d1393179ddbc9fb175c88b8de9 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/data/dataset.py @@ -0,0 +1,148 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import glob +import io +import logging +import unicodedata +from pathlib import Path, PurePath +from typing import Callable, Optional, Union + +import lmdb +from PIL import Image + +from torch.utils.data import ConcatDataset, Dataset + +from IndicPhotoOCR.utils.strhub.data.utils import CharsetAdapter + +log = logging.getLogger(__name__) + + +def build_tree_dataset(root: Union[PurePath, str], *args, **kwargs): + try: + kwargs.pop('root') # prevent 'root' from being passed via kwargs + except KeyError: + pass + root = Path(root).absolute() + log.info(f'dataset root:\t{root}') + datasets = [] + for mdb in glob.glob(str(root / '**/data.mdb'), recursive=True): + mdb = Path(mdb) + ds_name = str(mdb.parent.relative_to(root)) + ds_root = str(mdb.parent.absolute()) + dataset = LmdbDataset(ds_root, *args, **kwargs) + log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}') + datasets.append(dataset) + return ConcatDataset(datasets) + + +class LmdbDataset(Dataset): + """Dataset interface to an LMDB database. + + It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned + as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset. + Labels are transformed according to the charset. + """ + + def __init__( + self, + root: str, + charset: str, + max_label_len: int, + min_image_dim: int = 0, + remove_whitespace: bool = True, + normalize_unicode: bool = True, + unlabelled: bool = False, + transform: Optional[Callable] = None, + ): + self._env = None + self.root = root + self.unlabelled = unlabelled + self.transform = transform + self.labels = [] + self.filtered_index_list = [] + self.num_samples = self._preprocess_labels( + charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim + ) + + def __del__(self): + if self._env is not None: + self._env.close() + self._env = None + + def _create_env(self): + return lmdb.open( + self.root, max_readers=1, readonly=True, create=False, readahead=False, meminit=False, lock=False + ) + + @property + def env(self): + if self._env is None: + self._env = self._create_env() + return self._env + + def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim): + charset_adapter = CharsetAdapter(charset) + with self._create_env() as env, env.begin() as txn: + num_samples = int(txn.get('num-samples'.encode())) + if self.unlabelled: + return num_samples + for index in range(num_samples): + index += 1 # lmdb starts with 1 + label_key = f'label-{index:09d}'.encode() + label = txn.get(label_key).decode() + # Normally, whitespace is removed from the labels. + if remove_whitespace: + label = ''.join(label.split()) + # Normalize unicode composites (if any) and convert to compatible ASCII characters + if normalize_unicode: + label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode() + # Filter by length before removing unsupported characters. The original label might be too long. 
+ if len(label) > max_label_len: + continue + label = charset_adapter(label) + # We filter out samples which don't contain any supported characters + if not label: + continue + # Filter images that are too small. + if min_image_dim > 0: + img_key = f'image-{index:09d}'.encode() + buf = io.BytesIO(txn.get(img_key)) + w, h = Image.open(buf).size + if w < self.min_image_dim or h < self.min_image_dim: + continue + self.labels.append(label) + self.filtered_index_list.append(index) + return len(self.labels) + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + if self.unlabelled: + label = index + else: + label = self.labels[index] + index = self.filtered_index_list[index] + + img_key = f'image-{index:09d}'.encode() + with self.env.begin() as txn: + imgbuf = txn.get(img_key) + buf = io.BytesIO(imgbuf) + img = Image.open(buf).convert('RGB') + + if self.transform is not None: + img = self.transform(img) + + return img, label diff --git a/IndicPhotoOCR/utils/strhub/data/module.py b/IndicPhotoOCR/utils/strhub/data/module.py new file mode 100644 index 0000000000000000000000000000000000000000..336a4ba467dd59d9b146e5eebf38e7225e348b12 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/data/module.py @@ -0,0 +1,157 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
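For reference, a hedged usage sketch of the LmdbDataset defined above; the LMDB root and charset are placeholders. Note also that _preprocess_labels reads self.min_image_dim, which __init__ never assigns, so the size filter is only safe while min_image_dim stays at its default of 0.

# Illustrative only, assuming an LMDB folder in the usual strhub layout.
from torchvision import transforms as T
from IndicPhotoOCR.utils.strhub.data.dataset import LmdbDataset

transform = T.Compose([
    T.Resize((32, 128), T.InterpolationMode.BICUBIC),  # same preprocessing as get_transform in the recogniser above
    T.ToTensor(),
    T.Normalize(0.5, 0.5),
])
ds = LmdbDataset("data/train/real",  # placeholder LMDB root
                 charset="0123456789abcdefghijklmnopqrstuvwxyz",  # placeholder charset
                 max_label_len=25, transform=transform)
img, label = ds[0]  # (3, 32, 128) tensor and its text label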
+ +from pathlib import PurePath +from typing import Callable, Optional, Sequence + +from torch.utils.data import DataLoader +from torchvision import transforms as T + +import pytorch_lightning as pl + +from IndicPhotoOCR.utils.strhub.data.dataset import LmdbDataset, build_tree_dataset + + +class SceneTextDataModule(pl.LightningDataModule): + TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80') + TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80') + TEST_NEW = ('ArT', 'COCOv1.4', 'Uber') + TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW)) + + def __init__( + self, + root_dir: str, + train_dir: str, + img_size: Sequence[int], + max_label_length: int, + charset_train: str, + charset_test: str, + batch_size: int, + num_workers: int, + augment: bool, + remove_whitespace: bool = True, + normalize_unicode: bool = True, + min_image_dim: int = 0, + rotation: int = 0, + collate_fn: Optional[Callable] = None, + ): + super().__init__() + self.root_dir = root_dir + self.train_dir = train_dir + self.img_size = tuple(img_size) + self.max_label_length = max_label_length + self.charset_train = charset_train + self.charset_test = charset_test + self.batch_size = batch_size + self.num_workers = num_workers + self.augment = augment + self.remove_whitespace = remove_whitespace + self.normalize_unicode = normalize_unicode + self.min_image_dim = min_image_dim + self.rotation = rotation + self.collate_fn = collate_fn + self._train_dataset = None + self._val_dataset = None + + @staticmethod + def get_transform(img_size: tuple[int], augment: bool = False, rotation: int = 0): + transforms = [] + if augment: + from .augment import rand_augment_transform + + transforms.append(rand_augment_transform()) + if rotation: + transforms.append(lambda img: img.rotate(rotation, expand=True)) + transforms.extend([ + T.Resize(img_size, T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(0.5, 0.5), + ]) + return T.Compose(transforms) + + @property + def train_dataset(self): + if self._train_dataset is None: + transform = self.get_transform(self.img_size, self.augment) + root = PurePath(self.root_dir, 'train', self.train_dir) + self._train_dataset = build_tree_dataset( + root, + self.charset_train, + self.max_label_length, + self.min_image_dim, + self.remove_whitespace, + self.normalize_unicode, + transform=transform, + ) + return self._train_dataset + + @property + def val_dataset(self): + if self._val_dataset is None: + transform = self.get_transform(self.img_size) + root = PurePath(self.root_dir, 'val') + self._val_dataset = build_tree_dataset( + root, + self.charset_test, + self.max_label_length, + self.min_image_dim, + self.remove_whitespace, + self.normalize_unicode, + transform=transform, + ) + return self._val_dataset + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + persistent_workers=self.num_workers > 0, + pin_memory=True, + collate_fn=self.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=self.num_workers > 0, + pin_memory=True, + collate_fn=self.collate_fn, + ) + + def test_dataloaders(self, subset): + transform = self.get_transform(self.img_size, rotation=self.rotation) + root = PurePath(self.root_dir, 'test') + datasets = { + s: LmdbDataset( + str(root / s), + self.charset_test, + self.max_label_length, + 
self.min_image_dim, + self.remove_whitespace, + self.normalize_unicode, + transform=transform, + ) + for s in subset + } + return { + k: DataLoader( + v, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=self.collate_fn + ) + for k, v in datasets.items() + } diff --git a/IndicPhotoOCR/utils/strhub/data/utils.py b/IndicPhotoOCR/utils/strhub/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..16fd30d0bba424361730b9d1c33016746de23f68 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/data/utils.py @@ -0,0 +1,150 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from abc import ABC, abstractmethod +from itertools import groupby +from typing import Optional + +import torch +from torch import Tensor +from torch.nn.utils.rnn import pad_sequence + + +class CharsetAdapter: + """Transforms labels according to the target charset.""" + + def __init__(self, target_charset) -> None: + super().__init__() + self.lowercase_only = target_charset == target_charset.lower() + self.uppercase_only = target_charset == target_charset.upper() + self.unsupported = re.compile(f'[^{re.escape(target_charset)}]') + + def __call__(self, label): + if self.lowercase_only: + label = label.lower() + elif self.uppercase_only: + label = label.upper() + # Remove unsupported characters + label = self.unsupported.sub('', label) + return label + + +class BaseTokenizer(ABC): + + def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None: + self._itos = specials_first + tuple(charset) + specials_last + self._stoi = {s: i for i, s in enumerate(self._itos)} + + def __len__(self): + return len(self._itos) + + def _tok2ids(self, tokens: str) -> list[int]: + return [self._stoi[s] for s in tokens] + + def _ids2tok(self, token_ids: list[int], join: bool = True) -> str: + tokens = [self._itos[i] for i in token_ids] + return ''.join(tokens) if join else tokens + + @abstractmethod + def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor: + """Encode a batch of labels to a representation suitable for the model. + + Args: + labels: List of labels. Each can be of arbitrary length. + device: Create tensor on this device. + + Returns: + Batched tensor representation padded to the max label length. Shape: N, L + """ + raise NotImplementedError + + @abstractmethod + def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]: + """Internal method which performs the necessary filtering prior to decoding.""" + raise NotImplementedError + + def decode(self, token_dists: Tensor, raw: bool = False) -> tuple[list[str], list[Tensor]]: + """Decode a batch of token distributions. + + Args: + token_dists: softmax probabilities over the token distribution. 
Shape: N, L, C + raw: return unprocessed labels (will return list of list of strings) + + Returns: + list of string labels (arbitrary length) and + their corresponding sequence probabilities as a list of Tensors + """ + batch_tokens = [] + batch_probs = [] + for dist in token_dists: + probs, ids = dist.max(-1) # greedy selection + if not raw: + probs, ids = self._filter(probs, ids) + tokens = self._ids2tok(ids, not raw) + batch_tokens.append(tokens) + batch_probs.append(probs) + return batch_tokens, batch_probs + + +class Tokenizer(BaseTokenizer): + BOS = '[B]' + EOS = '[E]' + PAD = '[P]' + + def __init__(self, charset: str) -> None: + specials_first = (self.EOS,) + specials_last = (self.BOS, self.PAD) + super().__init__(charset, specials_first, specials_last) + self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last] + + def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor: + batch = [ + torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device) + for y in labels + ] + return pad_sequence(batch, batch_first=True, padding_value=self.pad_id) + + def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]: + ids = ids.tolist() + try: + eos_idx = ids.index(self.eos_id) + except ValueError: + eos_idx = len(ids) # Nothing to truncate. + # Truncate after EOS + ids = ids[:eos_idx] + probs = probs[: eos_idx + 1] # but include prob. for EOS (if it exists) + return probs, ids + + +class CTCTokenizer(BaseTokenizer): + BLANK = '[B]' + + def __init__(self, charset: str) -> None: + # BLANK uses index == 0 by default + super().__init__(charset, specials_first=(self.BLANK,)) + self.blank_id = self._stoi[self.BLANK] + + def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor: + # We use a padded representation since we don't want to use CUDNN's CTC implementation + batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels] + return pad_sequence(batch, batch_first=True, padding_value=self.blank_id) + + def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]: + # Best path decoding: + ids = list(zip(*groupby(ids.tolist())))[0] # Remove duplicate tokens + ids = [x for x in ids if x != self.blank_id] # Remove BLANKs + # `probs` is just pass-through since all positions are considered part of the path + return probs, ids diff --git a/IndicPhotoOCR/utils/strhub/models/__init__.py b/IndicPhotoOCR/utils/strhub/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5570dcb032819c28de6b73c4de5bba4109e08c0f --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/__init__.py @@ -0,0 +1 @@ +# from .utils import load_from_checkpoint \ No newline at end of file diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/LICENSE b/IndicPhotoOCR/utils/strhub/models/abinet/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2f1d4adb4889b2719f13ed6edf56aed10246a516 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/abinet/LICENSE @@ -0,0 +1,25 @@ +ABINet for non-commercial purposes + +Copyright (c) 2021, USTC +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. 
Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/__init__.py b/IndicPhotoOCR/utils/strhub/models/abinet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..604811036fda52d8485eecfebd4ffeb7f7176042 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/abinet/__init__.py @@ -0,0 +1,13 @@ +r""" +Fang, Shancheng, Hongtao, Xie, Yuxin, Wang, Zhendong, Mao, and Yongdong, Zhang. +"Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition." . +In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 7098-7107).2021. + +https://arxiv.org/abs/2103.06495 + +All source files, except `system.py`, are based on the implementation listed below, +and hence are released under the license of the original. 
+ +Source: https://github.com/FangShancheng/ABINet +License: 2-clause BSD License (see included LICENSE file) +""" diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/attention.py b/IndicPhotoOCR/utils/strhub/models/abinet/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..cc8fba0638e7444fdffe964f72d0566c1a5bb818 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/abinet/attention.py @@ -0,0 +1,100 @@ +import torch +import torch.nn as nn + +from .transformer import PositionalEncoding + + +class Attention(nn.Module): + def __init__(self, in_channels=512, max_length=25, n_feature=256): + super().__init__() + self.max_length = max_length + + self.f0_embedding = nn.Embedding(max_length, in_channels) + self.w0 = nn.Linear(max_length, n_feature) + self.wv = nn.Linear(in_channels, in_channels) + self.we = nn.Linear(in_channels, max_length) + + self.active = nn.Tanh() + self.softmax = nn.Softmax(dim=2) + + def forward(self, enc_output): + enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2) + reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device) + reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S) + reading_order_embed = self.f0_embedding(reading_order) # b,25,512 + + t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256 + t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512 + + attn = self.we(t) # b,256,25 + attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256 + g_output = torch.bmm(attn, enc_output) # b,25,512 + return g_output, attn.view(*attn.shape[:2], 8, 32) + + +def encoder_layer(in_c, out_c, k=3, s=2, p=1): + return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p), + nn.BatchNorm2d(out_c), + nn.ReLU(True)) + + +def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None): + align_corners = None if mode == 'nearest' else True + return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor, + mode=mode, align_corners=align_corners), + nn.Conv2d(in_c, out_c, k, s, p), + nn.BatchNorm2d(out_c), + nn.ReLU(True)) + + +class PositionAttention(nn.Module): + def __init__(self, max_length, in_channels=512, num_channels=64, + h=8, w=32, mode='nearest', **kwargs): + super().__init__() + self.max_length = max_length + self.k_encoder = nn.Sequential( + encoder_layer(in_channels, num_channels, s=(1, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)) + ) + self.k_decoder = nn.Sequential( + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, in_channels, size=(h, w), mode=mode) + ) + + self.pos_encoder = PositionalEncoding(in_channels, dropout=0., max_len=max_length) + self.project = nn.Linear(in_channels, in_channels) + + def forward(self, x): + N, E, H, W = x.size() + k, v = x, x # (N, E, H, W) + + # calculate key vector + features = [] + for i in range(0, len(self.k_encoder)): + k = self.k_encoder[i](k) + features.append(k) + for i in range(0, len(self.k_decoder) - 1): + k = self.k_decoder[i](k) + k = k + features[len(self.k_decoder) - 2 - i] + k = self.k_decoder[-1](k) + + # calculate query vector + # TODO q=f(q,k) + zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E) + q = self.pos_encoder(zeros) # (T, N, E) + q = 
q.permute(1, 0, 2) # (N, T, E) + q = self.project(q) # (N, T, E) + + # calculate attention + attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W)) + attn_scores = attn_scores / (E ** 0.5) + attn_scores = torch.softmax(attn_scores, dim=-1) + + v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E) + attn_vecs = torch.bmm(attn_scores, v) # (N, T, E) + + return attn_vecs, attn_scores.view(N, -1, H, W) diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/backbone.py b/IndicPhotoOCR/utils/strhub/models/abinet/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..debcabd7f115db0e698a55175a01a0ff0131e10f --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/abinet/backbone.py @@ -0,0 +1,24 @@ +import torch.nn as nn +from torch.nn import TransformerEncoderLayer, TransformerEncoder + +from .resnet import resnet45 +from .transformer import PositionalEncoding + + +class ResTranformer(nn.Module): + def __init__(self, d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', backbone_ln=2): + super().__init__() + self.resnet = resnet45() + self.pos_encoder = PositionalEncoding(d_model, max_len=8 * 32) + encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead, + dim_feedforward=d_inner, dropout=dropout, activation=activation) + self.transformer = TransformerEncoder(encoder_layer, backbone_ln) + + def forward(self, images): + feature = self.resnet(images) + n, c, h, w = feature.shape + feature = feature.view(n, c, -1).permute(2, 0, 1) + feature = self.pos_encoder(feature) + feature = self.transformer(feature) + feature = feature.permute(1, 2, 0).view(n, c, h, w) + return feature diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/model.py b/IndicPhotoOCR/utils/strhub/models/abinet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..cc0cd143d324822c57b897b6e5749024d857fd30 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/abinet/model.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn + + +class Model(nn.Module): + + def __init__(self, dataset_max_length: int, null_label: int): + super().__init__() + self.max_length = dataset_max_length + 1 # additional stop token + self.null_label = null_label + + def _get_length(self, logit, dim=-1): + """ Greed decoder to obtain length from logit""" + out = (logit.argmax(dim=-1) == self.null_label) + abn = out.any(dim) + out = ((out.cumsum(dim) == 1) & out).max(dim)[1] + out = out + 1 # additional end token + out = torch.where(abn, out, out.new_tensor(logit.shape[1], device=out.device)) + return out + + @staticmethod + def _get_padding_mask(length, max_length): + length = length.unsqueeze(-1) + grid = torch.arange(0, max_length, device=length.device).unsqueeze(0) + return grid >= length + + @staticmethod + def _get_location_mask(sz, device=None): + mask = torch.eye(sz, device=device) + mask = mask.float().masked_fill(mask == 1, float('-inf')) + return mask diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/model_abinet_iter.py b/IndicPhotoOCR/utils/strhub/models/abinet/model_abinet_iter.py new file mode 100644 index 0000000000000000000000000000000000000000..1a8523ff6431f991037d56dc8dd72ae67c7bf242 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/abinet/model_abinet_iter.py @@ -0,0 +1,39 @@ +import torch +from torch import nn + +from .model_alignment import BaseAlignment +from .model_language import BCNLanguage +from .model_vision import BaseVision + + +class ABINetIterModel(nn.Module): + def __init__(self, dataset_max_length, null_label, num_classes, iter_size=1, + 
d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', + v_loss_weight=1., v_attention='position', v_attention_mode='nearest', + v_backbone='transformer', v_num_layers=2, + l_loss_weight=1., l_num_layers=4, l_detach=True, l_use_self_attn=False, + a_loss_weight=1.): + super().__init__() + self.iter_size = iter_size + self.vision = BaseVision(dataset_max_length, null_label, num_classes, v_attention, v_attention_mode, + v_loss_weight, d_model, nhead, d_inner, dropout, activation, v_backbone, v_num_layers) + self.language = BCNLanguage(dataset_max_length, null_label, num_classes, d_model, nhead, d_inner, dropout, + activation, l_num_layers, l_detach, l_use_self_attn, l_loss_weight) + self.alignment = BaseAlignment(dataset_max_length, null_label, num_classes, d_model, a_loss_weight) + + def forward(self, images): + v_res = self.vision(images) + a_res = v_res + all_l_res, all_a_res = [], [] + for _ in range(self.iter_size): + tokens = torch.softmax(a_res['logits'], dim=-1) + lengths = a_res['pt_lengths'] + lengths.clamp_(2, self.language.max_length) # TODO:move to langauge model + l_res = self.language(tokens, lengths) + all_l_res.append(l_res) + a_res = self.alignment(l_res['feature'], v_res['feature']) + all_a_res.append(a_res) + if self.training: + return all_a_res, all_l_res, v_res + else: + return a_res, all_l_res[-1], v_res diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/model_alignment.py b/IndicPhotoOCR/utils/strhub/models/abinet/model_alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..9ccfa95e65dbd7176c8bcee693bb0bcb8ad13c69 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/abinet/model_alignment.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn + +from .model import Model + + +class BaseAlignment(Model): + def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, loss_weight=1.0): + super().__init__(dataset_max_length, null_label) + self.loss_weight = loss_weight + self.w_att = nn.Linear(2 * d_model, d_model) + self.cls = nn.Linear(d_model, num_classes) + + def forward(self, l_feature, v_feature): + """ + Args: + l_feature: (N, T, E) where T is length, N is batch size and d is dim of model + v_feature: (N, T, E) shape the same as l_feature + """ + f = torch.cat((l_feature, v_feature), dim=2) + f_att = torch.sigmoid(self.w_att(f)) + output = f_att * v_feature + (1 - f_att) * l_feature + + logits = self.cls(output) # (N, T, C) + pt_lengths = self._get_length(logits) + + return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight': self.loss_weight, + 'name': 'alignment'} diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/model_language.py b/IndicPhotoOCR/utils/strhub/models/abinet/model_language.py new file mode 100644 index 0000000000000000000000000000000000000000..aa8bb8f60b61ad96dca3c54f7db94e19ffd42b83 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/abinet/model_language.py @@ -0,0 +1,49 @@ +import torch.nn as nn + +from .model import Model +from .transformer import PositionalEncoding, TransformerDecoderLayer, TransformerDecoder + + +class BCNLanguage(Model): + def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, nhead=8, d_inner=2048, dropout=0.1, + activation='relu', num_layers=4, detach=True, use_self_attn=False, loss_weight=1.0, + global_debug=False): + super().__init__(dataset_max_length, null_label) + self.detach = detach + self.loss_weight = loss_weight + self.proj = nn.Linear(num_classes, d_model, False) + self.token_encoder = PositionalEncoding(d_model, 
max_len=self.max_length) + self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length) + decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout, + activation, self_attn=use_self_attn, debug=global_debug) + self.model = TransformerDecoder(decoder_layer, num_layers) + self.cls = nn.Linear(d_model, num_classes) + + def forward(self, tokens, lengths): + """ + Args: + tokens: (N, T, C) where T is length, N is batch size and C is classes number + lengths: (N,) + """ + if self.detach: + tokens = tokens.detach() + embed = self.proj(tokens) # (N, T, E) + embed = embed.permute(1, 0, 2) # (T, N, E) + embed = self.token_encoder(embed) # (T, N, E) + padding_mask = self._get_padding_mask(lengths, self.max_length) + + zeros = embed.new_zeros(*embed.shape) + qeury = self.pos_encoder(zeros) + location_mask = self._get_location_mask(self.max_length, tokens.device) + output = self.model(qeury, embed, + tgt_key_padding_mask=padding_mask, + memory_mask=location_mask, + memory_key_padding_mask=padding_mask) # (T, N, E) + output = output.permute(1, 0, 2) # (N, T, E) + + logits = self.cls(output) # (N, T, C) + pt_lengths = self._get_length(logits) + + res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths, + 'loss_weight': self.loss_weight, 'name': 'language'} + return res diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/model_vision.py b/IndicPhotoOCR/utils/strhub/models/abinet/model_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..bddb7d5f237854b81c388090e2e20fc26632c431 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/abinet/model_vision.py @@ -0,0 +1,45 @@ +from torch import nn + +from .attention import PositionAttention, Attention +from .backbone import ResTranformer +from .model import Model +from .resnet import resnet45 + + +class BaseVision(Model): + def __init__(self, dataset_max_length, null_label, num_classes, + attention='position', attention_mode='nearest', loss_weight=1.0, + d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', + backbone='transformer', backbone_ln=2): + super().__init__(dataset_max_length, null_label) + self.loss_weight = loss_weight + self.out_channels = d_model + + if backbone == 'transformer': + self.backbone = ResTranformer(d_model, nhead, d_inner, dropout, activation, backbone_ln) + else: + self.backbone = resnet45() + + if attention == 'position': + self.attention = PositionAttention( + max_length=self.max_length, + mode=attention_mode + ) + elif attention == 'attention': + self.attention = Attention( + max_length=self.max_length, + n_feature=8 * 32, + ) + else: + raise ValueError(f'invalid attention: {attention}') + + self.cls = nn.Linear(self.out_channels, num_classes) + + def forward(self, images): + features = self.backbone(images) # (N, E, H, W) + attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W) + logits = self.cls(attn_vecs) # (N, T, C) + pt_lengths = self._get_length(logits) + + return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths, + 'attn_scores': attn_scores, 'loss_weight': self.loss_weight, 'name': 'vision'} diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/resnet.py b/IndicPhotoOCR/utils/strhub/models/abinet/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..59bf38896987b3560e254e8037426d29bcdd5844 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/abinet/resnet.py @@ -0,0 +1,72 @@ +import math +from typing import Optional, Callable + +import torch.nn as nn +from 
torchvision.models import resnet + + +class BasicBlock(resnet.BasicBlock): + + def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, + groups: int = 1, base_width: int = 64, dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None) -> None: + super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer) + self.conv1 = resnet.conv1x1(inplanes, planes) + self.conv2 = resnet.conv3x3(planes, planes, stride) + + +class ResNet(nn.Module): + + def __init__(self, block, layers): + super().__init__() + self.inplanes = 32 + self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, 32, layers[0], stride=2) + self.layer2 = self._make_layer(block, 64, layers[1], stride=1) + self.layer3 = self._make_layer(block, 128, layers[2], stride=2) + self.layer4 = self._make_layer(block, 256, layers[3], stride=1) + self.layer5 = self._make_layer(block, 512, layers[4], stride=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.layer5(x) + return x + + +def resnet45(): + return ResNet(BasicBlock, [3, 4, 6, 6, 3]) diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/system.py b/IndicPhotoOCR/utils/strhub/models/abinet/system.py new file mode 100644 index 0000000000000000000000000000000000000000..f56e9d1dff021318095d28fb5eb99cace5371ecb --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/abinet/system.py @@ -0,0 +1,215 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
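As a quick sanity check on the backbone geometry: resnet45 above downsamples only in layer1 and layer3, so the 32x128 crops used throughout strhub come out as an 8x32 grid of 512-dimensional features, matching the max_len=8*32 positional encoding in backbone.py and the h=8, w=32 defaults in attention.py. A hedged shape check, assuming this patch is installed as the IndicPhotoOCR package:

# Illustrative shape check for the ABINet backbone defined above.
import torch
from IndicPhotoOCR.utils.strhub.models.abinet.resnet import resnet45

features = resnet45()(torch.randn(1, 3, 32, 128))
print(features.shape)  # expected: torch.Size([1, 512, 8, 32])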
+ +import logging +import math +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.optim import AdamW +from torch.optim.lr_scheduler import OneCycleLR + +from pytorch_lightning.utilities.types import STEP_OUTPUT +from timm.optim.optim_factory import param_groups_weight_decay + +from strhub.models.base import CrossEntropySystem +from strhub.models.utils import init_weights + +from .model_abinet_iter import ABINetIterModel as Model + +log = logging.getLogger(__name__) + + +class ABINet(CrossEntropySystem): + + def __init__( + self, + charset_train: str, + charset_test: str, + max_label_length: int, + batch_size: int, + lr: float, + warmup_pct: float, + weight_decay: float, + iter_size: int, + d_model: int, + nhead: int, + d_inner: int, + dropout: float, + activation: str, + v_loss_weight: float, + v_attention: str, + v_attention_mode: str, + v_backbone: str, + v_num_layers: int, + l_loss_weight: float, + l_num_layers: int, + l_detach: bool, + l_use_self_attn: bool, + l_lr: float, + a_loss_weight: float, + lm_only: bool = False, + **kwargs, + ) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.scheduler = None + self.save_hyperparameters() + self.max_label_length = max_label_length + self.num_classes = len(self.tokenizer) - 2 # We don't predict nor + self.model = Model( + max_label_length, + self.eos_id, + self.num_classes, + iter_size, + d_model, + nhead, + d_inner, + dropout, + activation, + v_loss_weight, + v_attention, + v_attention_mode, + v_backbone, + v_num_layers, + l_loss_weight, + l_num_layers, + l_detach, + l_use_self_attn, + a_loss_weight, + ) + self.model.apply(init_weights) + # FIXME: doesn't support resumption from checkpoint yet + self._reset_alignment = True + self._reset_optimizers = True + self.l_lr = l_lr + self.lm_only = lm_only + # Train LM only. Freeze other submodels. + if lm_only: + self.l_lr = lr # for tuning + self.model.vision.requires_grad_(False) + self.model.alignment.requires_grad_(False) + + @property + def _pretraining(self): + # In the original work, VM was pretrained for 8 epochs while full model was trained for an additional 10 epochs. + total_steps = self.trainer.estimated_stepping_batches * self.trainer.accumulate_grad_batches + return self.global_step < (8 / (8 + 10)) * total_steps + + @torch.jit.ignore + def no_weight_decay(self): + return {'model.language.proj.weight'} + + def _add_weight_decay(self, model: nn.Module, skip_list=()): + if self.weight_decay: + return param_groups_weight_decay(model, self.weight_decay, skip_list) + else: + return [{'params': model.parameters()}] + + def configure_optimizers(self): + agb = self.trainer.accumulate_grad_batches + # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP. + lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.0 + lr = lr_scale * self.lr + l_lr = lr_scale * self.l_lr + params = [] + params.extend(self._add_weight_decay(self.model.vision)) + params.extend(self._add_weight_decay(self.model.alignment)) + # We use a different learning rate for the LM. 
+ for p in self._add_weight_decay(self.model.language, ('proj.weight',)): + p['lr'] = l_lr + params.append(p) + max_lr = [p.get('lr', lr) for p in params] + optim = AdamW(params, lr) + self.scheduler = OneCycleLR( + optim, max_lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, cycle_momentum=False + ) + return {'optimizer': optim, 'lr_scheduler': {'scheduler': self.scheduler, 'interval': 'step'}} + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) + logits = self.model.forward(images)[0]['logits'] + return logits[:, : max_length + 1] # truncate + + def calc_loss(self, targets, *res_lists) -> Tensor: + total_loss = 0 + for res_list in res_lists: + loss = 0 + if isinstance(res_list, dict): + res_list = [res_list] + for res in res_list: + logits = res['logits'].flatten(end_dim=1) + loss += F.cross_entropy(logits, targets.flatten(), ignore_index=self.pad_id) + loss /= len(res_list) + self.log('loss_' + res_list[0]['name'], loss) + total_loss += res_list[0]['loss_weight'] * loss + return total_loss + + def on_train_batch_start(self, batch: Any, batch_idx: int) -> None: + if not self._pretraining and self._reset_optimizers: + log.info('Pretraining ends. Updating base LRs.') + self._reset_optimizers = False + # Make base_lr the same for all groups + base_lr = self.scheduler.base_lrs[0] # base_lr of group 0 - VM + self.scheduler.base_lrs = [base_lr] * len(self.scheduler.base_lrs) + + def _prepare_inputs_and_targets(self, labels): + # Use dummy label to ensure sequence length is constant. + dummy = ['0' * self.max_label_length] + targets = self.tokenizer.encode(dummy + list(labels), self.device)[1:] + targets = targets[:, 1:] # remove . Unused here. + # Inputs are padded with eos_id + inputs = torch.where(targets == self.pad_id, self.eos_id, targets) + inputs = F.one_hot(inputs, self.num_classes).float() + lengths = torch.as_tensor(list(map(len, labels)), device=self.device) + 1 # +1 for eos + return inputs, lengths, targets + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + inputs, lengths, targets = self._prepare_inputs_and_targets(labels) + if self.lm_only: + l_res = self.model.language(inputs, lengths) + loss = self.calc_loss(targets, l_res) + # Pretrain submodels independently first + elif self._pretraining: + # Vision + v_res = self.model.vision(images) + # Language + l_res = self.model.language(inputs, lengths) + # We also train the alignment model to 'satisfy' DDP requirements (all parameters should be used). + # We'll reset its parameters prior to joint training. + a_res = self.model.alignment(l_res['feature'].detach(), v_res['feature'].detach()) + loss = self.calc_loss(targets, v_res, l_res, a_res) + else: + # Reset alignment model's parameters once prior to full model training. + if self._reset_alignment: + log.info('Pretraining ends. 
Resetting alignment model.') + self._reset_alignment = False + self.model.alignment.apply(init_weights) + all_a_res, all_l_res, v_res = self.model.forward(images) + loss = self.calc_loss(targets, v_res, all_l_res, all_a_res) + self.log('loss', loss) + return loss + + def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]: + if self.lm_only: + inputs, lengths, targets = self._prepare_inputs_and_targets(labels) + l_res = self.model.language(inputs, lengths) + loss = self.calc_loss(targets, l_res) + loss_numel = (targets != self.pad_id).sum() + return l_res['logits'], loss, loss_numel + else: + return super().forward_logits_loss(images, labels) diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/transformer.py b/IndicPhotoOCR/utils/strhub/models/abinet/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..03ae4b13976ddc67dfb2e2bfd83885a823cf9ecb --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/abinet/transformer.py @@ -0,0 +1,198 @@ +import math + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.modules.transformer import _get_activation_fn, _get_clones + + +class TransformerDecoder(nn.Module): + r"""TransformerDecoder is a stack of N decoder layers + + Args: + decoder_layer: an instance of the TransformerDecoderLayer() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = transformer_decoder(tgt, memory) + """ + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm=None): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, tgt, memory, memory2=None, tgt_mask=None, + memory_mask=None, memory_mask2=None, tgt_key_padding_mask=None, + memory_key_padding_mask=None, memory_key_padding_mask2=None): + # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor + r"""Pass the inputs (and mask) through the decoder layer in turn. + + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + output = tgt + + for mod in self.layers: + output = mod(output, memory, memory2=memory2, tgt_mask=tgt_mask, + memory_mask=memory_mask, memory_mask2=memory_mask2, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + memory_key_padding_mask2=memory_key_padding_mask2) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoderLayer(nn.Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". 
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", self_attn=True, siamese=False, debug=False): + super().__init__() + self.has_self_attn, self.siamese = self_attn, siamese + self.debug = debug + if self.has_self_attn: + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.norm1 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + if self.siamese: + self.multihead_attn2 = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super().__setstate__(state) + + def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, + tgt_key_padding_mask=None, memory_key_padding_mask=None, + memory2=None, memory_mask2=None, memory_key_padding_mask2=None): + # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor + r"""Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. 
+ """ + if self.has_self_attn: + tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + if self.debug: self.attn = attn + tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + if self.debug: self.attn2 = attn2 + + if self.siamese: + tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2, + key_padding_mask=memory_key_padding_mask2) + tgt = tgt + self.dropout2(tgt3) + if self.debug: self.attn3 = attn3 + + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + + return tgt + + +class PositionalEncoding(nn.Module): + r"""Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, d_model, dropout=0.1, max_len=5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer('pe', pe) + + def forward(self, x): + r"""Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + + x = x + self.pe[:x.size(0), :] + return self.dropout(x) diff --git a/IndicPhotoOCR/utils/strhub/models/base.py b/IndicPhotoOCR/utils/strhub/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..42d61efd6311445c604446ddae208bdef3d7b115 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/base.py @@ -0,0 +1,221 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
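# Illustrative sketch (arbitrary sizes) of how the PositionalEncoding defined in
# transformer.py above is used by the ABINet language model: the decoder queries are
# pure positional encodings, obtained by passing a zero tensor through the module,
# mirroring the `pos_encoder(zeros)` call in model_language.py. dropout=0 keeps the
# output deterministic.
import torch

from IndicPhotoOCR.utils.strhub.models.abinet.transformer import PositionalEncoding

pos_encoder = PositionalEncoding(d_model=512, dropout=0, max_len=26)
zeros = torch.zeros(26, 4, 512)                   # (T, N, E) layout used by this module
queries = pos_encoder(zeros)                      # sinusoidal table broadcast over the batch
assert queries.shape == (26, 4, 512)
assert torch.equal(queries[:, 0], queries[:, 1])  # identical encoding for every batch element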
+ +import math +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional + +from nltk import edit_distance + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.optim import Optimizer +from torch.optim.lr_scheduler import OneCycleLR + +import pytorch_lightning as pl +from pytorch_lightning.utilities.types import STEP_OUTPUT +from timm.optim import create_optimizer_v2 + +from IndicPhotoOCR.utils.strhub.data.utils import BaseTokenizer, CharsetAdapter, CTCTokenizer, Tokenizer + + +@dataclass +class BatchResult: + num_samples: int + correct: int + ned: float + confidence: float + label_length: int + loss: Tensor + loss_numel: int + + +EPOCH_OUTPUT = list[dict[str, BatchResult]] + + +class BaseSystem(pl.LightningModule, ABC): + + def __init__( + self, + tokenizer: BaseTokenizer, + charset_test: str, + batch_size: int, + lr: float, + warmup_pct: float, + weight_decay: float, + ) -> None: + super().__init__() + self.tokenizer = tokenizer + self.charset_adapter = CharsetAdapter(charset_test) + self.batch_size = batch_size + self.lr = lr + self.warmup_pct = warmup_pct + self.weight_decay = weight_decay + self.outputs: EPOCH_OUTPUT = [] + + @abstractmethod + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + """Inference + + Args: + images: Batch of images. Shape: N, Ch, H, W + max_length: Max sequence length of the output. If None, will use default. + + Returns: + logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials) + """ + raise NotImplementedError + + @abstractmethod + def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]: + """Like forward(), but also computes the loss (calls forward() internally). + + Args: + images: Batch of images. Shape: N, Ch, H, W + labels: Text labels of the images + + Returns: + logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials) + loss: mean loss for the batch + loss_numel: number of elements the loss was calculated from + """ + raise NotImplementedError + + def configure_optimizers(self): + agb = self.trainer.accumulate_grad_batches + # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP. + lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.0 + lr = lr_scale * self.lr + optim = create_optimizer_v2(self, 'adamw', lr, self.weight_decay) + sched = OneCycleLR( + optim, lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, cycle_momentum=False + ) + return {'optimizer': optim, 'lr_scheduler': {'scheduler': sched, 'interval': 'step'}} + + def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer) -> None: + optimizer.zero_grad(set_to_none=True) + + def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]: + images, labels = batch + + correct = 0 + total = 0 + ned = 0 + confidence = 0 + label_length = 0 + if validation: + logits, loss, loss_numel = self.forward_logits_loss(images, labels) + else: + # At test-time, we shouldn't specify a max_label_length because the test-time charset used + # might be different from the train-time charset. max_label_length in eval_logits_loss() is computed + # based on the transformed label, which could be wrong if the actual gt label contains characters existing + # in the train-time charset but not in the test-time charset. 
For example, "aishahaleyes.blogspot.com" + # is exactly 25 characters, but if processed by CharsetAdapter for the 36-char set, it becomes 23 characters + # long only, which sets max_label_length = 23. This will cause the model prediction to be truncated. + logits = self.forward(images) + loss = loss_numel = None # Only used for validation; not needed at test-time. + + probs = logits.softmax(-1) + preds, probs = self.tokenizer.decode(probs) + for pred, prob, gt in zip(preds, probs, labels): + confidence += prob.prod().item() + pred = self.charset_adapter(pred) + # Follow ICDAR 2019 definition of N.E.D. + ned += edit_distance(pred, gt) / max(len(pred), len(gt)) + if pred == gt: + correct += 1 + total += 1 + label_length += len(pred) + return dict(output=BatchResult(total, correct, ned, confidence, label_length, loss, loss_numel)) + + @staticmethod + def _aggregate_results(outputs: EPOCH_OUTPUT) -> tuple[float, float, float]: + if not outputs: + return 0.0, 0.0, 0.0 + total_loss = 0 + total_loss_numel = 0 + total_n_correct = 0 + total_norm_ED = 0 + total_size = 0 + for result in outputs: + result = result['output'] + total_loss += result.loss_numel * result.loss + total_loss_numel += result.loss_numel + total_n_correct += result.correct + total_norm_ED += result.ned + total_size += result.num_samples + acc = total_n_correct / total_size + ned = 1 - total_norm_ED / total_size + loss = total_loss / total_loss_numel + return acc, ned, loss + + def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]: + result = self._eval_step(batch, True) + self.outputs.append(result) + return result + + def on_validation_epoch_end(self) -> None: + acc, ned, loss = self._aggregate_results(self.outputs) + self.outputs.clear() + self.log('val_accuracy', 100 * acc, sync_dist=True) + self.log('val_NED', 100 * ned, sync_dist=True) + self.log('val_loss', loss, sync_dist=True) + self.log('hp_metric', acc, sync_dist=True) + + def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]: + return self._eval_step(batch, False) + + +class CrossEntropySystem(BaseSystem): + + def __init__( + self, charset_train: str, charset_test: str, batch_size: int, lr: float, warmup_pct: float, weight_decay: float + ) -> None: + tokenizer = Tokenizer(charset_train) + super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.bos_id = tokenizer.bos_id + self.eos_id = tokenizer.eos_id + self.pad_id = tokenizer.pad_id + + def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]: + targets = self.tokenizer.encode(labels, self.device) + targets = targets[:, 1:] # Discard + max_len = targets.shape[1] - 1 # exclude from count + logits = self.forward(images, max_len) + loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id) + loss_numel = (targets != self.pad_id).sum() + return logits, loss, loss_numel + + +class CTCSystem(BaseSystem): + + def __init__( + self, charset_train: str, charset_test: str, batch_size: int, lr: float, warmup_pct: float, weight_decay: float + ) -> None: + tokenizer = CTCTokenizer(charset_train) + super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.blank_id = tokenizer.blank_id + + def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]: + targets = self.tokenizer.encode(labels, self.device) + logits = self.forward(images) + log_probs = logits.log_softmax(-1).transpose(0, 1) # swap batch and seq. 
dims + T, N, _ = log_probs.shape + input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long, device=self.device) + target_lengths = torch.as_tensor(list(map(len, labels)), dtype=torch.long, device=self.device) + loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True) + return logits, loss, N diff --git a/IndicPhotoOCR/utils/strhub/models/crnn/LICENSE b/IndicPhotoOCR/utils/strhub/models/crnn/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f98687be392fdce266708e79885aadaa4991b67f --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/crnn/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017 Jieru Mei + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/IndicPhotoOCR/utils/strhub/models/crnn/__init__.py b/IndicPhotoOCR/utils/strhub/models/crnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4535947d9233c8fb0a85e9c22b151697d37f410 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/crnn/__init__.py @@ -0,0 +1,13 @@ +r""" +Shi, Baoguang, Xiang Bai, and Cong Yao. +"An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition." +IEEE transactions on pattern analysis and machine intelligence 39, no. 11 (2016): 2298-2304. + +https://arxiv.org/abs/1507.05717 + +All source files, except `system.py`, are based on the implementation listed below, +and hence are released under the license of the original. 
+ +Source: https://github.com/meijieru/crnn.pytorch +License: MIT License (see included LICENSE file) +""" diff --git a/IndicPhotoOCR/utils/strhub/models/crnn/model.py b/IndicPhotoOCR/utils/strhub/models/crnn/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5c9e8e6a1a2f3d4ed32c976f47a8cbdff22946 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/crnn/model.py @@ -0,0 +1,62 @@ +import torch.nn as nn + +from strhub.models.modules import BidirectionalLSTM + + +class CRNN(nn.Module): + + def __init__(self, img_h, nc, nclass, nh, leaky_relu=False): + super().__init__() + assert img_h % 16 == 0, 'img_h has to be a multiple of 16' + + ks = [3, 3, 3, 3, 3, 3, 2] + ps = [1, 1, 1, 1, 1, 1, 0] + ss = [1, 1, 1, 1, 1, 1, 1] + nm = [64, 128, 256, 256, 512, 512, 512] + + cnn = nn.Sequential() + + def convRelu(i, batchNormalization=False): + nIn = nc if i == 0 else nm[i - 1] + nOut = nm[i] + cnn.add_module(f'conv{i}', + nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i], bias=not batchNormalization)) + if batchNormalization: + cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(nOut)) + if leaky_relu: + cnn.add_module(f'relu{i}', + nn.LeakyReLU(0.2, inplace=True)) + else: + cnn.add_module(f'relu{i}', nn.ReLU(True)) + + convRelu(0) + cnn.add_module('pooling0', nn.MaxPool2d(2, 2)) # 64x16x64 + convRelu(1) + cnn.add_module('pooling1', nn.MaxPool2d(2, 2)) # 128x8x32 + convRelu(2, True) + convRelu(3) + cnn.add_module('pooling2', + nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 + convRelu(4, True) + convRelu(5) + cnn.add_module('pooling3', + nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16 + convRelu(6, True) # 512x1x16 + + self.cnn = cnn + self.rnn = nn.Sequential( + BidirectionalLSTM(512, nh, nh), + BidirectionalLSTM(nh, nh, nclass)) + + def forward(self, input): + # conv features + conv = self.cnn(input) + b, c, h, w = conv.size() + assert h == 1, 'the height of conv must be 1' + conv = conv.squeeze(2) + conv = conv.transpose(1, 2) # [b, w, c] + + # rnn features + output = self.rnn(conv) + + return output diff --git a/IndicPhotoOCR/utils/strhub/models/crnn/system.py b/IndicPhotoOCR/utils/strhub/models/crnn/system.py new file mode 100644 index 0000000000000000000000000000000000000000..a69dfdd131cc3895bfc7b0aaa0832681dbfaab25 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/crnn/system.py @@ -0,0 +1,56 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
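# Rough shape sketch for the CRNN defined in model.py above (illustrative sizes, and
# assuming the `strhub` package referenced by its own imports is resolvable at runtime):
# the conv stack collapses a 32x128 RGB crop to a height-1 feature map, leaving a
# 33-step sequence that the stacked bidirectional LSTMs map to per-step class logits.
import torch

from IndicPhotoOCR.utils.strhub.models.crnn.model import CRNN

model = CRNN(img_h=32, nc=3, nclass=37, nh=256)   # 37 = hypothetical charset size incl. blank
images = torch.randn(2, 3, 32, 128)
logits = model(images)
assert logits.shape == (2, 33, 37)                # (batch, time steps, classes)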
+ +from typing import Optional, Sequence + +from torch import Tensor + +from pytorch_lightning.utilities.types import STEP_OUTPUT + +from strhub.models.base import CTCSystem +from strhub.models.utils import init_weights + +from .model import CRNN as Model + + +class CRNN(CTCSystem): + + def __init__( + self, + charset_train: str, + charset_test: str, + max_label_length: int, + batch_size: int, + lr: float, + warmup_pct: float, + weight_decay: float, + img_size: Sequence[int], + hidden_size: int, + leaky_relu: bool, + **kwargs, + ) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + self.model = Model(img_size[0], 3, len(self.tokenizer), hidden_size, leaky_relu) + self.model.apply(init_weights) + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + return self.model.forward(images) + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + loss = self.forward_logits_loss(images, labels)[1] + self.log('loss', loss) + return loss diff --git a/IndicPhotoOCR/utils/strhub/models/modules.py b/IndicPhotoOCR/utils/strhub/models/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a89d05f6afd67437f3cfa8aff6d2d8b12df3fafa --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/modules.py @@ -0,0 +1,20 @@ +r"""Shared modules used by CRNN and TRBA""" +from torch import nn + + +class BidirectionalLSTM(nn.Module): + """Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py""" + + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) + self.linear = nn.Linear(hidden_size * 2, output_size) + + def forward(self, input): + """ + input : visual feature [batch_size x T x input_size], T = num_steps. + output : contextual feature [batch_size x T x output_size] + """ + recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) + output = self.linear(recurrent) # batch_size x T x output_size + return output diff --git a/IndicPhotoOCR/utils/strhub/models/parseq/__init__.py b/IndicPhotoOCR/utils/strhub/models/parseq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/IndicPhotoOCR/utils/strhub/models/parseq/model.py b/IndicPhotoOCR/utils/strhub/models/parseq/model.py new file mode 100644 index 0000000000000000000000000000000000000000..28f4f3d06d7ccc005cb40e47a5e86626a54aa04d --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/parseq/model.py @@ -0,0 +1,169 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
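# Tiny usage sketch for the shared BidirectionalLSTM module above (arbitrary example
# sizes): it maps a (batch, T, input_size) visual feature sequence to
# (batch, T, output_size) contextual features, as used by both CRNN and TRBA.
import torch

from IndicPhotoOCR.utils.strhub.models.modules import BidirectionalLSTM

rnn = BidirectionalLSTM(input_size=512, hidden_size=256, output_size=256)
feats = torch.randn(2, 33, 512)        # (batch, time steps, channels)
out = rnn(feats)
assert out.shape == (2, 33, 256)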
+ +from functools import partial +from typing import Optional, Sequence + +import torch +import torch.nn as nn +from torch import Tensor + +from timm.models.helpers import named_apply + +from IndicPhotoOCR.utils.strhub.data.utils import Tokenizer +from IndicPhotoOCR.utils.strhub.models.utils import init_weights + +from .modules import Decoder, DecoderLayer, Encoder, TokenEmbedding + + +class PARSeq(nn.Module): + + def __init__( + self, + num_tokens: int, + max_label_length: int, + img_size: Sequence[int], + patch_size: Sequence[int], + embed_dim: int, + enc_num_heads: int, + enc_mlp_ratio: int, + enc_depth: int, + dec_num_heads: int, + dec_mlp_ratio: int, + dec_depth: int, + decode_ar: bool, + refine_iters: int, + dropout: float, + ) -> None: + super().__init__() + + self.max_label_length = max_label_length + self.decode_ar = decode_ar + self.refine_iters = refine_iters + + self.encoder = Encoder( + img_size, patch_size, embed_dim=embed_dim, depth=enc_depth, num_heads=enc_num_heads, mlp_ratio=enc_mlp_ratio + ) + decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout) + self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=nn.LayerNorm(embed_dim)) + + # We don't predict nor + self.head = nn.Linear(embed_dim, num_tokens - 2) + self.text_embed = TokenEmbedding(num_tokens, embed_dim) + + # +1 for + self.pos_queries = nn.Parameter(torch.Tensor(1, max_label_length + 1, embed_dim)) + self.dropout = nn.Dropout(p=dropout) + # Encoder has its own init. + named_apply(partial(init_weights, exclude=['encoder']), self) + nn.init.trunc_normal_(self.pos_queries, std=0.02) + + @property + def _device(self) -> torch.device: + return next(self.head.parameters(recurse=False)).device + + @torch.jit.ignore + def no_weight_decay(self): + param_names = {'text_embed.embedding.weight', 'pos_queries'} + enc_param_names = {'encoder.' + n for n in self.encoder.no_weight_decay()} + return param_names.union(enc_param_names) + + def encode(self, img: torch.Tensor): + return self.encoder(img) + + def decode( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[Tensor] = None, + tgt_padding_mask: Optional[Tensor] = None, + tgt_query: Optional[Tensor] = None, + tgt_query_mask: Optional[Tensor] = None, + ): + N, L = tgt.shape + # stands for the null context. We only supply position information for characters after . + null_ctx = self.text_embed(tgt[:, :1]) + tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:]) + tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1)) + if tgt_query is None: + tgt_query = self.pos_queries[:, :L].expand(N, -1, -1) + tgt_query = self.dropout(tgt_query) + return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask) + + def forward(self, tokenizer: Tokenizer, images: Tensor, max_length: Optional[int] = None) -> Tensor: + testing = max_length is None + max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) + bs = images.shape[0] + # +1 for at end of sequence. + num_steps = max_length + 1 + memory = self.encode(images) + + # Query positions up to `num_steps` + pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1) + + # Special case for the forward permutation. 
Faster than using `generate_attn_masks()` + tgt_mask = query_mask = torch.triu(torch.ones((num_steps, num_steps), dtype=torch.bool, device=self._device), 1) + + if self.decode_ar: + tgt_in = torch.full((bs, num_steps), tokenizer.pad_id, dtype=torch.long, device=self._device) + tgt_in[:, 0] = tokenizer.bos_id + + logits = [] + for i in range(num_steps): + j = i + 1 # next token index + # Efficient decoding: + # Input the context up to the ith token. We use only one query (at position = i) at a time. + # This works because of the lookahead masking effect of the canonical (forward) AR context. + # Past tokens have no access to future tokens, hence are fixed once computed. + tgt_out = self.decode( + tgt_in[:, :j], + memory, + tgt_mask[:j, :j], + tgt_query=pos_queries[:, i:j], + tgt_query_mask=query_mask[i:j, :j], + ) + # the next token probability is in the output's ith token position + p_i = self.head(tgt_out) + logits.append(p_i) + if j < num_steps: + # greedy decode. add the next token index to the target input + tgt_in[:, j] = p_i.squeeze().argmax(-1) + # Efficient batch decoding: If all output words have at least one EOS token, end decoding. + if testing and (tgt_in == tokenizer.eos_id).any(dim=-1).all(): + break + + logits = torch.cat(logits, dim=1) + else: + # No prior context, so input is just . We query all positions. + tgt_in = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device) + tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries) + logits = self.head(tgt_out) + + if self.refine_iters: + # For iterative refinement, we always use a 'cloze' mask. + # We can derive it from the AR forward mask by unmasking the token context to the right. + query_mask[torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool, device=self._device), 2)] = 0 + bos = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device) + for i in range(self.refine_iters): + # Prior context is the previous output. + tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1) + # Mask tokens beyond the first EOS token. + tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(-1) > 0 + tgt_out = self.decode( + tgt_in, memory, tgt_mask, tgt_padding_mask, pos_queries, query_mask[:, : tgt_in.shape[1]] + ) + logits = self.head(tgt_out) + + return logits diff --git a/IndicPhotoOCR/utils/strhub/models/parseq/modules.py b/IndicPhotoOCR/utils/strhub/models/parseq/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..e3e3f23f6ee52c9de8b21df63efe7299100eb44d --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/parseq/modules.py @@ -0,0 +1,176 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
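# Small illustration of the forward (causal) attention mask built in PARSeq.forward
# above with torch.triu (illustrative num_steps): True marks disallowed future
# positions, so the query at position i may only attend to positions 0..i.
import torch

num_steps = 4
mask = torch.triu(torch.ones((num_steps, num_steps), dtype=torch.bool), 1)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
print(mask)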
+ +import math +from typing import Optional + +import torch +from torch import Tensor, nn as nn +from torch.nn import functional as F +from torch.nn.modules import transformer + +from timm.models.vision_transformer import PatchEmbed, VisionTransformer + + +class DecoderLayer(nn.Module): + """A Transformer decoder layer supporting two-stream attention (XLNet) + This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch.""" + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', layer_norm_eps=1e-5): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) + self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = transformer._get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.gelu + super().__setstate__(state) + + def forward_stream( + self, + tgt: Tensor, + tgt_norm: Tensor, + tgt_kv: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor], + tgt_key_padding_mask: Optional[Tensor], + ): + """Forward pass for a single stream (i.e. content or query) + tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency. + Both tgt_kv and memory are expected to be LayerNorm'd too. + memory is LayerNorm'd by ViT. 
+ """ + tgt2, sa_weights = self.self_attn( + tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + ) + tgt = tgt + self.dropout1(tgt2) + + tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory) + tgt = tgt + self.dropout2(tgt2) + + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt))))) + tgt = tgt + self.dropout3(tgt2) + return tgt, sa_weights, ca_weights + + def forward( + self, + query, + content, + memory, + query_mask: Optional[Tensor] = None, + content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None, + update_content: bool = True, + ): + query_norm = self.norm_q(query) + content_norm = self.norm_c(content) + query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0] + if update_content: + content = self.forward_stream( + content, content_norm, content_norm, memory, content_mask, content_key_padding_mask + )[0] + return query, content + + +class Decoder(nn.Module): + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm): + super().__init__() + self.layers = transformer._get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + query, + content, + memory, + query_mask: Optional[Tensor] = None, + content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None, + ): + for i, mod in enumerate(self.layers): + last = i == len(self.layers) - 1 + query, content = mod( + query, content, memory, query_mask, content_mask, content_key_padding_mask, update_content=not last + ) + query = self.norm(query) + return query + + +class Encoder(VisionTransformer): + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + embed_layer=PatchEmbed, + ): + super().__init__( + img_size, + patch_size, + in_chans, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + embed_layer=embed_layer, + num_classes=0, # These + global_pool='', # disable the + class_token=False, # classifier head. + ) + + def forward(self, x): + # Return all tokens + return self.forward_features(x) + + +class TokenEmbedding(nn.Module): + + def __init__(self, charset_size: int, embed_dim: int): + super().__init__() + self.embedding = nn.Embedding(charset_size, embed_dim) + self.embed_dim = embed_dim + + def forward(self, tokens: torch.Tensor): + return math.sqrt(self.embed_dim) * self.embedding(tokens) diff --git a/IndicPhotoOCR/utils/strhub/models/parseq/system.py b/IndicPhotoOCR/utils/strhub/models/parseq/system.py new file mode 100644 index 0000000000000000000000000000000000000000..217275f8345f84ce88be97fb8932a30e51baf6ca --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/parseq/system.py @@ -0,0 +1,200 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from itertools import permutations +from typing import Any, Optional, Sequence + +import numpy as np + +import torch +import torch.nn.functional as F +from torch import Tensor + +from pytorch_lightning.utilities.types import STEP_OUTPUT + +from IndicPhotoOCR.utils.strhub.models.base import CrossEntropySystem + +from .model import PARSeq as Model + + +class PARSeq(CrossEntropySystem): + + def __init__( + self, + charset_train: str, + charset_test: str, + max_label_length: int, + batch_size: int, + lr: float, + warmup_pct: float, + weight_decay: float, + img_size: Sequence[int], + patch_size: Sequence[int], + embed_dim: int, + enc_num_heads: int, + enc_mlp_ratio: int, + enc_depth: int, + dec_num_heads: int, + dec_mlp_ratio: int, + dec_depth: int, + perm_num: int, + perm_forward: bool, + perm_mirrored: bool, + decode_ar: bool, + refine_iters: int, + dropout: float, + **kwargs: Any, + ) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + + self.model = Model( + len(self.tokenizer), + max_label_length, + img_size, + patch_size, + embed_dim, + enc_num_heads, + enc_mlp_ratio, + enc_depth, + dec_num_heads, + dec_mlp_ratio, + dec_depth, + decode_ar, + refine_iters, + dropout, + ) + + # Perm/attn mask stuff + self.rng = np.random.default_rng() + self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num + self.perm_forward = perm_forward + self.perm_mirrored = perm_mirrored + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + return self.model.forward(self.tokenizer, images, max_length) + + def gen_tgt_perms(self, tgt): + """Generate shared permutations for the whole batch. + This works because the same attention mask can be used for the shorter sequences + because of the padding mask. + """ + # We don't permute the position of BOS, we permute EOS separately + max_num_chars = tgt.shape[1] - 2 + # Special handling for 1-character sequences + if max_num_chars == 1: + return torch.arange(3, device=self._device).unsqueeze(0) + perms = [torch.arange(max_num_chars, device=self._device)] if self.perm_forward else [] + # Additional permutations if needed + max_perms = math.factorial(max_num_chars) + if self.perm_mirrored: + max_perms //= 2 + num_gen_perms = min(self.max_gen_perms, max_perms) + # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions + # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars. + if max_num_chars < 5: + # Pool of permutations to sample from. 
We only need the first half (if complementary option is selected) + # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves + if max_num_chars == 4 and self.perm_mirrored: + selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21] + else: + selector = list(range(max_perms)) + perm_pool = torch.as_tensor( + list(permutations(range(max_num_chars), max_num_chars)), + device=self._device, + )[selector] + # If the forward permutation is always selected, no need to add it to the pool for sampling + if self.perm_forward: + perm_pool = perm_pool[1:] + perms = torch.stack(perms) + if len(perm_pool): + i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(perms), replace=False) + perms = torch.cat([perms, perm_pool[i]]) + else: + perms.extend( + [torch.randperm(max_num_chars, device=self._device) for _ in range(num_gen_perms - len(perms))] + ) + perms = torch.stack(perms) + if self.perm_mirrored: + # Add complementary pairs + comp = perms.flip(-1) + # Stack in such a way that the pairs are next to each other. + perms = torch.stack([perms, comp]).transpose(0, 1).reshape(-1, max_num_chars) + # NOTE: + # The only meaningful way of permuting the EOS position is by moving it one character position at a time. + # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS + # positions will always be much less than the number of permutations (unless a low perm_num is set). + # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly + # distribute it across the chosen number of permutations. + # Add position indices of BOS and EOS + bos_idx = perms.new_zeros((len(perms), 1)) + eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1) + perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1) + # Special handling for the reverse direction. This does two things: + # 1. Reverse context for the characters + # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode) + if len(perms) > 1: + perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=self._device) + return perms + + def generate_attn_masks(self, perm): + """Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens) + :param perm: the permutation sequence. i = 0 is always the BOS + :return: lookahead attention masks + """ + sz = perm.shape[0] + mask = torch.zeros((sz, sz), dtype=torch.bool, device=self._device) + for i in range(sz): + query_idx = perm[i] + masked_keys = perm[i + 1 :] + mask[query_idx, masked_keys] = True + content_mask = mask[:-1, :-1].clone() + mask[torch.eye(sz, dtype=torch.bool, device=self._device)] = True # mask "self" + query_mask = mask[1:, :-1] + return content_mask, query_mask + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + tgt = self.tokenizer.encode(labels, self._device) + + # Encode the source sequence (i.e. 
the image codes) + memory = self.model.encode(images) + + # Prepare the target sequences (input and output) + tgt_perms = self.gen_tgt_perms(tgt) + tgt_in = tgt[:, :-1] + tgt_out = tgt[:, 1:] + # The [EOS] token is not depended upon by any other token in any permutation ordering + tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id) + + loss = 0 + loss_numel = 0 + n = (tgt_out != self.pad_id).sum().item() + for i, perm in enumerate(tgt_perms): + tgt_mask, query_mask = self.generate_attn_masks(perm) + out = self.model.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask) + logits = self.model.head(out).flatten(end_dim=1) + loss += n * F.cross_entropy(logits, tgt_out.flatten(), ignore_index=self.pad_id) + loss_numel += n + # After the second iteration (i.e. done with canonical and reverse orderings), + # remove the [EOS] tokens for the succeeding perms + if i == 1: + tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id, tgt_out) + n = (tgt_out != self.pad_id).sum().item() + loss /= loss_numel + + self.log('loss', loss) + return loss diff --git a/IndicPhotoOCR/utils/strhub/models/trba/__init__.py b/IndicPhotoOCR/utils/strhub/models/trba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a574a8af95e7f1ffaa05c45b4cd22f4a3cc0a5c0 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/trba/__init__.py @@ -0,0 +1,13 @@ +r""" +Baek, Jeonghun, Geewook Kim, Junyeop Lee, Sungrae Park, Dongyoon Han, Sangdoo Yun, Seong Joon Oh, and Hwalsuk Lee. +"What is wrong with scene text recognition model comparisons? dataset and model analysis." +In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4715-4723. 2019. + +https://arxiv.org/abs/1904.01906 + +All source files, except `system.py`, are based on the implementation listed below, +and hence are released under the license of the original. 
+ +Source: https://github.com/clovaai/deep-text-recognition-benchmark +License: Apache License 2.0 (see LICENSE file in project root) +""" diff --git a/IndicPhotoOCR/utils/strhub/models/trba/feature_extraction.py b/IndicPhotoOCR/utils/strhub/models/trba/feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..17646e3ff83ad28c1021237824a838e38c3b6345 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/trba/feature_extraction.py @@ -0,0 +1,110 @@ +import torch.nn as nn + +from torchvision.models.resnet import BasicBlock + + +class ResNet_FeatureExtractor(nn.Module): + """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super().__init__() + self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) + + def forward(self, input): + return self.ConvNet(input) + + +class ResNet(nn.Module): + + def __init__(self, input_channel, output_channel, block, layers): + super().__init__() + + self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] + + self.inplanes = int(output_channel / 8) + self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) + self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_2 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU(inplace=True) + + self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) + self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ + 0], kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) + + self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) + self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ + 1], kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) + + self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) + self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) + self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ + 2], kernel_size=3, stride=1, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) + + self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) + self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) + self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) + self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=1, padding=0, bias=False) + self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + 
layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv0_1(x) + x = self.bn0_1(x) + x = self.relu(x) + x = self.conv0_2(x) + x = self.bn0_2(x) + x = self.relu(x) + + x = self.maxpool1(x) + x = self.layer1(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.maxpool2(x) + x = self.layer2(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + x = self.maxpool3(x) + x = self.layer3(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + + x = self.layer4(x) + x = self.conv4_1(x) + x = self.bn4_1(x) + x = self.relu(x) + x = self.conv4_2(x) + x = self.bn4_2(x) + x = self.relu(x) + + return x diff --git a/IndicPhotoOCR/utils/strhub/models/trba/model.py b/IndicPhotoOCR/utils/strhub/models/trba/model.py new file mode 100644 index 0000000000000000000000000000000000000000..41161a4df4e2ff368bfe1c62f681c6964510a0c0 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/trba/model.py @@ -0,0 +1,55 @@ +import torch.nn as nn + +from strhub.models.modules import BidirectionalLSTM +from .feature_extraction import ResNet_FeatureExtractor +from .prediction import Attention +from .transformation import TPS_SpatialTransformerNetwork + + +class TRBA(nn.Module): + + def __init__(self, img_h, img_w, num_class, num_fiducial=20, input_channel=3, output_channel=512, hidden_size=256, + use_ctc=False): + super().__init__() + """ Transformation """ + self.Transformation = TPS_SpatialTransformerNetwork( + F=num_fiducial, I_size=(img_h, img_w), I_r_size=(img_h, img_w), + I_channel_num=input_channel) + + """ FeatureExtraction """ + self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel) + self.FeatureExtraction_output = output_channel + self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 + + """ Sequence modeling""" + self.SequenceModeling = nn.Sequential( + BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size), + BidirectionalLSTM(hidden_size, hidden_size, hidden_size)) + self.SequenceModeling_output = hidden_size + + """ Prediction """ + if use_ctc: + self.Prediction = nn.Linear(self.SequenceModeling_output, num_class) + else: + self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class) + + def forward(self, image, max_label_length, text=None): + """ Transformation stage """ + image = self.Transformation(image) + + """ Feature extraction stage """ + visual_feature = self.FeatureExtraction(image) + visual_feature = visual_feature.permute(0, 3, 1, 2) # [b, c, h, w] -> [b, w, c, h] + visual_feature = self.AdaptiveAvgPool(visual_feature) # [b, w, c, h] -> [b, w, c, 1] + visual_feature = visual_feature.squeeze(3) # [b, w, c, 1] -> [b, w, c] + + """ Sequence modeling stage """ + contextual_feature = self.SequenceModeling(visual_feature) # [b, num_steps, hidden_size] + + """ Prediction stage """ + if isinstance(self.Prediction, Attention): + prediction = self.Prediction(contextual_feature.contiguous(), text, max_label_length) + else: + prediction = self.Prediction(contextual_feature.contiguous()) # CTC + + return prediction # [b, num_steps, num_class] diff --git a/IndicPhotoOCR/utils/strhub/models/trba/prediction.py b/IndicPhotoOCR/utils/strhub/models/trba/prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..5609398a28ef5288d3f3971786c2cebc2e574336 --- /dev/null 
+++ b/IndicPhotoOCR/utils/strhub/models/trba/prediction.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Attention(nn.Module): + + def __init__(self, input_size, hidden_size, num_class, num_char_embeddings=256): + super().__init__() + self.attention_cell = AttentionCell(input_size, hidden_size, num_char_embeddings) + self.hidden_size = hidden_size + self.num_class = num_class + self.generator = nn.Linear(hidden_size, num_class) + self.char_embeddings = nn.Embedding(num_class, num_char_embeddings) + + def forward(self, batch_H, text, max_label_length=25): + """ + input: + batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_class] + text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [SOS] token. text[:, 0] = [SOS]. + output: probability distribution at each step [batch_size x num_steps x num_class] + """ + batch_size = batch_H.size(0) + num_steps = max_label_length + 1 # +1 for [EOS] at end of sentence. + + output_hiddens = batch_H.new_zeros((batch_size, num_steps, self.hidden_size), dtype=torch.float) + hidden = (batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float), + batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float)) + + if self.training: + for i in range(num_steps): + char_embeddings = self.char_embeddings(text[:, i]) + # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_embeddings : f(y_{t-1}) + hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) + output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) + probs = self.generator(output_hiddens) + + else: + targets = text[0].expand(batch_size) # should be fill with [SOS] token + probs = batch_H.new_zeros((batch_size, num_steps, self.num_class), dtype=torch.float) + + for i in range(num_steps): + char_embeddings = self.char_embeddings(targets) + hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) + probs_step = self.generator(hidden[0]) + probs[:, i, :] = probs_step + _, next_input = probs_step.max(1) + targets = next_input + + return probs # batch_size x num_steps x num_class + + +class AttentionCell(nn.Module): + + def __init__(self, input_size, hidden_size, num_embeddings): + super().__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias=False) + self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias + self.score = nn.Linear(hidden_size, 1, bias=False) + self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_embeddings): + # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) + e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 + + alpha = F.softmax(e, dim=1) + context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel + concat_context = torch.cat([context, char_embeddings], 1) # batch_size x (num_channel + num_embedding) + cur_hidden = self.rnn(concat_context, prev_hidden) + return cur_hidden, alpha diff --git a/IndicPhotoOCR/utils/strhub/models/trba/system.py b/IndicPhotoOCR/utils/strhub/models/trba/system.py new file mode 100644 index 0000000000000000000000000000000000000000..eabc5bacf3ef0a61d6b50cb1707c7da9eb2cb930 --- /dev/null +++ 
b/IndicPhotoOCR/utils/strhub/models/trba/system.py @@ -0,0 +1,125 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Any, Optional, Sequence + +import torch +import torch.nn.functional as F +from torch import Tensor + +from pytorch_lightning.utilities.types import STEP_OUTPUT +from timm.models.helpers import named_apply + +from strhub.models.base import CrossEntropySystem, CTCSystem +from strhub.models.utils import init_weights + +from .model import TRBA as Model + + +class TRBA(CrossEntropySystem): + + def __init__( + self, + charset_train: str, + charset_test: str, + max_label_length: int, + batch_size: int, + lr: float, + warmup_pct: float, + weight_decay: float, + img_size: Sequence[int], + num_fiducial: int, + output_channel: int, + hidden_size: int, + **kwargs: Any, + ) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + self.max_label_length = max_label_length + img_h, img_w = img_size + self.model = Model( + img_h, + img_w, + len(self.tokenizer), + num_fiducial, + output_channel=output_channel, + hidden_size=hidden_size, + use_ctc=False, + ) + named_apply(partial(init_weights, exclude=['Transformation.LocalizationNetwork.localization_fc2']), self.model) + + @torch.jit.ignore + def no_weight_decay(self): + return {'model.Prediction.char_embeddings.weight'} + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) + text = images.new_full([1], self.bos_id, dtype=torch.long) + return self.model.forward(images, max_length, text) + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + encoded = self.tokenizer.encode(labels, self.device) + inputs = encoded[:, :-1] # remove + targets = encoded[:, 1:] # remove + max_length = encoded.shape[1] - 2 # exclude and from count + logits = self.model.forward(images, max_length, inputs) + loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id) + self.log('loss', loss) + return loss + + +class TRBC(CTCSystem): + + def __init__( + self, + charset_train: str, + charset_test: str, + max_label_length: int, + batch_size: int, + lr: float, + warmup_pct: float, + weight_decay: float, + img_size: Sequence[int], + num_fiducial: int, + output_channel: int, + hidden_size: int, + **kwargs: Any, + ) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + self.max_label_length = max_label_length + img_h, img_w = img_size + self.model = Model( + img_h, + img_w, + len(self.tokenizer), + num_fiducial, + output_channel=output_channel, + hidden_size=hidden_size, + use_ctc=True, + ) + named_apply(partial(init_weights, exclude=['Transformation.LocalizationNetwork.localization_fc2']), self.model) + + def 
forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + # max_label_length is unused in CTC prediction + return self.model.forward(images, None) + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + loss = self.forward_logits_loss(images, labels)[1] + self.log('loss', loss) + return loss diff --git a/IndicPhotoOCR/utils/strhub/models/trba/transformation.py b/IndicPhotoOCR/utils/strhub/models/trba/transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..960419d135ec878aaaa3297c3ff5c22e998ef6be --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/trba/transformation.py @@ -0,0 +1,169 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TPS_SpatialTransformerNetwork(nn.Module): + """ Rectification Network of RARE, namely TPS based STN """ + + def __init__(self, F, I_size, I_r_size, I_channel_num=1): + """ Based on RARE TPS + input: + batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] + I_size : (height, width) of the input image I + I_r_size : (height, width) of the rectified image I_r + I_channel_num : the number of channels of the input image I + output: + batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] + """ + super().__init__() + self.F = F + self.I_size = I_size + self.I_r_size = I_r_size # = (I_r_height, I_r_width) + self.I_channel_num = I_channel_num + self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) + self.GridGenerator = GridGenerator(self.F, self.I_r_size) + + def forward(self, batch_I): + batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 + # batch_size x n (= I_r_width x I_r_height) x 2 + build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) + build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) + + if torch.__version__ > "1.2.0": + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) + else: + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') + + return batch_I_r + + +class LocalizationNetwork(nn.Module): + """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ + + def __init__(self, F, I_channel_num): + super().__init__() + self.F = F + self.I_channel_num = I_channel_num + self.conv = nn.Sequential( + nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, + bias=False), nn.BatchNorm2d(64), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 + nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 + nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 + nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), + nn.AdaptiveAvgPool2d(1) # batch_size x 512 + ) + + self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) + self.localization_fc2 = nn.Linear(256, self.F * 2) + + # Init fc2 in LocalizationNetwork + self.localization_fc2.weight.data.fill_(0) + """ see RARE paper Fig. 
6 (a) """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) + ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) + + def forward(self, batch_I): + """ + input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] + output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] + """ + batch_size = batch_I.size(0) + features = self.conv(batch_I).view(batch_size, -1) + batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) + return batch_C_prime + + +class GridGenerator(nn.Module): + """ Grid Generator of RARE, which produces P_prime by multipling T with P """ + + def __init__(self, F, I_r_size): + """ Generate P_hat and inv_delta_C for later """ + super().__init__() + self.eps = 1e-6 + self.I_r_height, self.I_r_width = I_r_size + self.F = F + self.C = self._build_C(self.F) # F x 2 + self.P = self._build_P(self.I_r_width, self.I_r_height) + + # num_gpu = torch.cuda.device_count() + # if num_gpu > 1: + # for multi-gpu, you may need register buffer + self.register_buffer("inv_delta_C", torch.tensor( + self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 + self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 + # else: + # # for fine-tuning with different image width, you may use below instead of self.register_buffer + # self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float() # F+3 x F+3 + # self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float() # n x F+3 + + def _build_C(self, F): + """ Return coordinates of fiducial points in I_r; C """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = -1 * np.ones(int(F / 2)) + ctrl_pts_y_bottom = np.ones(int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + return C # F x 2 + + def _build_inv_delta_C(self, F, C): + """ Return inv_delta_C which is needed to calculate T """ + hat_C = np.zeros((F, F), dtype=float) # F x F + for i in range(0, F): + for j in range(i, F): + r = np.linalg.norm(C[i] - C[j]) + hat_C[i, j] = r + hat_C[j, i] = r + np.fill_diagonal(hat_C, 1) + hat_C = (hat_C ** 2) * np.log(hat_C) + # print(C.shape, hat_C.shape) + delta_C = np.concatenate( # F+3 x F+3 + [ + np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 + np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 + np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 + ], + axis=0 + ) + inv_delta_C = np.linalg.inv(delta_C) + return inv_delta_C # F+3 x F+3 + + def _build_P(self, I_r_width, I_r_height): + I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width + I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height + P = np.stack( # self.I_r_width x self.I_r_height x 2 + np.meshgrid(I_r_grid_x, I_r_grid_y), + axis=2 + ) + return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 + + def _build_P_hat(self, F, C, P): + n = P.shape[0] # n 
(= self.I_r_width x self.I_r_height) + P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 + C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 + P_diff = P_tile - C_tile # n x F x 2 + rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F + rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F + P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) + return P_hat # n x F+3 + + def build_P_prime(self, batch_C_prime): + """ Generate Grid from batch_C_prime [batch_size x F x 2] """ + batch_size = batch_C_prime.size(0) + batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) + batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) + batch_C_prime_with_zeros = torch.cat((batch_C_prime, batch_C_prime.new_zeros( + (batch_size, 3, 2), dtype=torch.float)), dim=1) # batch_size x F+3 x 2 + batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 + batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 + return batch_P_prime # batch_size x n x 2 diff --git a/IndicPhotoOCR/utils/strhub/models/utils.py b/IndicPhotoOCR/utils/strhub/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4debeb25999c2c10e105bb3e32d5ae6f8d04ed9d --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/utils.py @@ -0,0 +1,125 @@ +from pathlib import PurePath +from typing import Sequence + +import yaml + +import torch +from torch import nn + + +class InvalidModelError(RuntimeError): + """Exception raised for any model-related error (creation, loading)""" + + +_WEIGHTS_URL = { + 'parseq-tiny': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq_tiny-e7a21b54.pt', + 'parseq-patch16-224': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq_small_patch16_224-fcf06f5a.pt', + 'parseq': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq-bb5792a6.pt', + 'abinet': 'https://github.com/baudm/parseq/releases/download/v1.0.0/abinet-1d1e373e.pt', + 'trba': 'https://github.com/baudm/parseq/releases/download/v1.0.0/trba-cfaed284.pt', + 'vitstr': 'https://github.com/baudm/parseq/releases/download/v1.0.0/vitstr-26d0fcf4.pt', + 'crnn': 'https://github.com/baudm/parseq/releases/download/v1.0.0/crnn-679d0e31.pt', +} + + +def _get_config(experiment: str, **kwargs): + """Emulates hydra config resolution""" + root = PurePath(__file__).parents[2] + with open(root / 'configs/main.yaml', 'r') as f: + config = yaml.load(f, yaml.Loader)['model'] + with open(root / 'configs/charset/94_full.yaml', 'r') as f: + config.update(yaml.load(f, yaml.Loader)['model']) + with open(root / f'configs/experiment/{experiment}.yaml', 'r') as f: + exp = yaml.load(f, yaml.Loader) + # Apply base model config + model = exp['defaults'][0]['override /model'] + with open(root / f'configs/model/{model}.yaml', 'r') as f: + config.update(yaml.load(f, yaml.Loader)) + # Apply experiment config + if 'model' in exp: + config.update(exp['model']) + config.update(kwargs) + # Workaround for now: manually cast the lr to the correct type. 
+ config['lr'] = float(config['lr']) + return config + + +def _get_model_class(key): + if 'abinet' in key: + from .abinet.system import ABINet as ModelClass + elif 'crnn' in key: + from .crnn.system import CRNN as ModelClass + elif 'parseq' in key: + from .parseq.system import PARSeq as ModelClass + elif 'trba' in key: + from .trba.system import TRBA as ModelClass + elif 'trbc' in key: + from .trba.system import TRBC as ModelClass + elif 'vitstr' in key: + from .vitstr.system import ViTSTR as ModelClass + else: + from .parseq.system import PARSeq as ModelClass + return ModelClass + + +def get_pretrained_weights(experiment): + try: + url = _WEIGHTS_URL[experiment] + except KeyError: + raise InvalidModelError(f"No pretrained weights found for '{experiment}'") from None + return torch.hub.load_state_dict_from_url(url=url, map_location='cpu', check_hash=True) + + +def create_model(experiment: str, pretrained: bool = False, **kwargs): + try: + config = _get_config(experiment, **kwargs) + except FileNotFoundError: + raise InvalidModelError(f"No configuration found for '{experiment}'") from None + ModelClass = _get_model_class(experiment) + model = ModelClass(**config) + if pretrained: + m = model.model if 'parseq' in experiment else model + m.load_state_dict(get_pretrained_weights(experiment)) + return model + + +def load_from_checkpoint(checkpoint_path: str, **kwargs): + if checkpoint_path.startswith('pretrained='): + model_id = checkpoint_path.split('=', maxsplit=1)[1] + model = create_model(model_id, True, **kwargs) + else: + ModelClass = _get_model_class(checkpoint_path) + model = ModelClass.load_from_checkpoint(checkpoint_path, **kwargs) + return model + + +def parse_model_args(args): + kwargs = {} + arg_types = {t.__name__: t for t in [int, float, str]} + arg_types['bool'] = lambda v: v.lower() == 'true' # special handling for bool + for arg in args: + name, value = arg.split('=', maxsplit=1) + name, arg_type = name.split(':', maxsplit=1) + kwargs[name] = arg_types[arg_type](value) + return kwargs + + +def init_weights(module: nn.Module, name: str = '', exclude: Sequence[str] = ()): + """Initialize the weights using the typical initialization schemes used in SOTA models.""" + if any(map(name.startswith, exclude)): + return + if isinstance(module, nn.Linear): + nn.init.trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.trunc_normal_(module.weight, std=0.02) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) diff --git a/IndicPhotoOCR/utils/strhub/models/vitstr/__init__.py b/IndicPhotoOCR/utils/strhub/models/vitstr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..19e985679da1fcaa6deb306697993fd601892d6c --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/vitstr/__init__.py @@ -0,0 +1,12 @@ +r""" +Atienza, Rowel. "Vision Transformer for Fast and Efficient Scene Text Recognition." +In International Conference on Document Analysis and Recognition (ICDAR). 2021. 
+ +https://arxiv.org/abs/2105.08582 + +All source files, except `system.py`, are based on the implementation listed below, +and hence are released under the license of the original. + +Source: https://github.com/roatienza/deep-text-recognition-benchmark +License: Apache License 2.0 (see LICENSE file in project root) +""" diff --git a/IndicPhotoOCR/utils/strhub/models/vitstr/model.py b/IndicPhotoOCR/utils/strhub/models/vitstr/model.py new file mode 100644 index 0000000000000000000000000000000000000000..62c5d551626c325243a4f0d055869384a59b3910 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/vitstr/model.py @@ -0,0 +1,28 @@ +""" +Implementation of ViTSTR based on timm VisionTransformer. + +TODO: +1) distilled deit backbone +2) base deit backbone + +Copyright 2021 Rowel Atienza +""" + +from timm.models.vision_transformer import VisionTransformer + + +class ViTSTR(VisionTransformer): + """ + ViTSTR is basically a ViT that uses DeiT weights. + Modified head to support a sequence of characters prediction for STR. + """ + + def forward(self, x, seqlen: int = 25): + x = self.forward_features(x) + x = x[:, :seqlen] + + # batch, seqlen, embsize + b, s, e = x.size() + x = x.reshape(b * s, e) + x = self.head(x).view(b, s, self.num_classes) + return x diff --git a/IndicPhotoOCR/utils/strhub/models/vitstr/system.py b/IndicPhotoOCR/utils/strhub/models/vitstr/system.py new file mode 100644 index 0000000000000000000000000000000000000000..37b762e1a074055873413655e35ef2605ffa8238 --- /dev/null +++ b/IndicPhotoOCR/utils/strhub/models/vitstr/system.py @@ -0,0 +1,79 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Sequence + +import torch +from torch import Tensor + +from pytorch_lightning.utilities.types import STEP_OUTPUT + +from strhub.models.base import CrossEntropySystem +from strhub.models.utils import init_weights + +from .model import ViTSTR as Model + + +class ViTSTR(CrossEntropySystem): + + def __init__( + self, + charset_train: str, + charset_test: str, + max_label_length: int, + batch_size: int, + lr: float, + warmup_pct: float, + weight_decay: float, + img_size: Sequence[int], + patch_size: Sequence[int], + embed_dim: int, + num_heads: int, + **kwargs: Any, + ) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + self.max_label_length = max_label_length + # We don't predict nor + self.model = Model( + img_size=img_size, + patch_size=patch_size, + depth=12, + mlp_ratio=4, + qkv_bias=True, + embed_dim=embed_dim, + num_heads=num_heads, + num_classes=len(self.tokenizer) - 2, + ) + # Non-zero weight init for the head + self.model.head.apply(init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {'model.' 
+ n for n in self.model.no_weight_decay()} + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) + logits = self.model.forward(images, max_length + 2) # +2 tokens for [GO] and [s] + # Truncate to conform to other models. [GO] in ViTSTR is actually used as the padding (therefore, ignored). + # First position corresponds to the class token, which is unused and ignored in the original work. + logits = logits[:, 1:] + return logits + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + loss = self.forward_logits_loss(images, labels)[1] + self.log('loss', loss) + return loss diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e36a8980c85818dc36895426b19ca07d36b357ad --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Bhashini Team@IIT Jodhpur + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 3e633a14fdf32fcfe4b4e4063b6ab0d633d60763..94f73b5ca0c3bd907c37d3a1ed40ec386a7e378f 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,192 @@ ---- -title: IndicPhotoOCR -emoji: 🐨 -colorFrom: pink -colorTo: purple -sdk: gradio -sdk_version: 5.6.0 -app_file: app.py -pinned: false -license: mit -short_description: Comprehensive Scene Text Recognition Toolkit across 13 India ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +

+ BharatOCR Logo +

+IndicPhotoOCR - Comprehensive Scene Text Recognition Toolkit
across 13 Indian Languages +

+

+
+ +![Open Source](https://img.shields.io/badge/Open%20Source-Bhashini-FF6C00) +[![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FBhashini-IITJ%2FBharatOCR&count_bg=%233D48C8&title_bg=%23555555&icon=&icon_color=%0C0983&title=hits&edge_flat=false)](https://hits.seeyoufarm.com) +[![GitHub stars](https://img.shields.io/github/stars/Bhashini-IITJ/BharatOCR.svg?style=social&label=Star&color=orange)](https://github.com/Bhashini-IITJ/BharatOCR/stargazers) +![GitHub forks](https://img.shields.io/github/forks/Bhashini-IITJ/BharatOCR?style=social) +[![Hugging Face](https://img.shields.io/badge/Hugging_Face-Demo-FF6C00?logo=Huggingface&logoColor=white)](https://huggingface.co/spaces/anikde/BharatOCR) + + +
+
+
+
+
+
+IndicPhotoOCR is an advanced OCR toolkit designed for detecting, identifying, and recognizing text across 13 Indian languages, including Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Meitei, Odia, Punjabi, Tamil, Telugu, Urdu, and English. Built to handle the unique scripts and complex structures of Indian languages, IndicPhotoOCR provides robust detection and recognition capabilities, making it a valuable tool for processing multilingual documents and enhancing document analysis in these diverse scripts.
+
+![](static/pics/visualizeIndicPhotoOCR.png)
+
+
+## Table of Contents
+[Updates](https://github.com/Bhashini-IITJ/BharatOCR/blob/main/README.md#updates)
+[Installation](https://github.com/Bhashini-IITJ/BharatOCR/blob/main/README.md#installation)
+[How to use](https://github.com/Bhashini-IITJ/BharatOCR/blob/main/README.md#how-to-use)
+[Acknowledgement](https://github.com/Bhashini-IITJ/BharatOCR/blob/main/README.md#acknowledgement)
+[Contact us](https://github.com/Bhashini-IITJ/BharatOCR/blob/main/README.md#contact-us)
+ +
+ + +## Updates +[November 2024]: Try demo in [huggingface space](https://huggingface.co/spaces/anikde/BharatOCR).\ +[November 2024]: Use this package in [Google Colab](https://colab.research.google.com/drive/1BILXjUF2kKKrzUJ_evubgLHl2busPiH2?usp=sharing).\ +[November 2024]: Added support for [10 languages](#config) in the recognition module.
+[September 2024]: Private repository created. +
+
+## Installation
+Currently, the virtual environment needs to be created manually.
+```bash
+conda create -n indicphotoocr python=3.9 -y
+conda activate indicphotoocr
+
+
+git clone https://github.com/Bhashini-IITJ/IndicPhotoOCR.git
+cd IndicPhotoOCR
+```
+
+ CPU Installation + + ```bash + python setup.py sdist bdist_wheel + pip install dist/IndicPhotoOCR-1.1.0-py3-none-any.whl[cpu] + ``` +
+ +
+ CUDA 11.8 Installation + + ```bash + python setup.py sdist bdist_wheel + pip install ./dist/IndicPhotoOCR-1.1.0-py3-none-any.whl[cu118] --extra-index-url https://download.pytorch.org/whl/cu118 + ``` +
+ +
+ CUDA 12.1 Installation + + ```bash + python setup.py sdist bdist_wheel + pip install ./dist/IndicPhotoOCR-1.1.0-py3-none-any.whl[cu121] --extra-index-url https://download.pytorch.org/whl/cu121 + ``` +
+
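+
+To verify the installation, a quick sanity check (a minimal sketch; it assumes the wheel built above is installed in the active `indicphotoocr` environment) is to import the package and create an `OCR` object:
+```python
+>>> from IndicPhotoOCR.ocr import OCR
+# Create an OCR object on CPU; pass device="cuda" if a CUDA wheel was installed
+>>> ocr_system = OCR(device="cpu", verbose=False)
+```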
+
+If you find any trouble with the above installation, use the `setup.sh` script.
+```bash
+chmod +x setup.sh
+./setup.sh
+```
+
+## Config
+Currently this model supports Hindi v/s English script identification, and thereby Hindi and English recognition.
+
+Detection Model: EAST\
+Script Identification Model: Hindi v/s English\
+Recognition Model: Hindi, English, Assamese, Bengali, Gujarati, Marathi, Odia, Punjabi, Tamil, Telugu.
+
+## How to use
+### Detection
+
+```python
+>>> from IndicPhotoOCR.ocr import OCR
+# Create an object of OCR
+>>> ocr_system = OCR(verbose=True) # for CPU --> OCR(device="cpu")
+
+# Get detections
+>>> detections = ocr_system.detect("test_images/image_141.jpg")
+
+# Running text detection...
+# 4334 text boxes before nms
+# 1.027989387512207
+
+# Save and visualize the detection results
+>>> ocr_system.visualize_detection("test_images/image_141.jpg", detections)
+# Image saved at: test.png
+```
+
+### Cropped Word Recognition
+```python
+>>> from IndicPhotoOCR.ocr import OCR
+# Create an object of OCR
+>>> ocr_system = OCR(verbose=True) # for CPU --> OCR(device="cpu")
+# Get recognitions
+>>> ocr_system.recognise("test_images/cropped_image/image_141_0.jpg", "hindi")
+# Recognizing text in detected area...
+# 'मण्डी'
+```
+
+### End-to-end Scene Text Recognition
+```python
+>>> from IndicPhotoOCR.ocr import OCR
+# Create an object of OCR
+>>> ocr_system = OCR(verbose=True) # for CPU --> OCR(device="cpu")
+# Complete pipeline
+>>> ocr_system.ocr("test_images/image_141.jpg")
+# Running text detection...
+# 4334 text boxes before nms
+# 0.9715704917907715
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Recognized word: रोड
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Recognized word: बाराखम्ब
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Using cache found in /DATA1/ocrteam/.cache/torch/hub/baudm_parseq_main
+# Recognized word: barakhaml
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Recognized word: हाऊस
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Using cache found in /DATA1/ocrteam/.cache/torch/hub/baudm_parseq_main
+# Recognized word: mandi
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Using cache found in /DATA1/ocrteam/.cache/torch/hub/baudm_parseq_main
+# Recognized word: chowk
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Recognized word: मण्डी
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Using cache found in /DATA1/ocrteam/.cache/torch/hub/baudm_parseq_main
+# Recognized word: road
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Using cache found in /DATA1/ocrteam/.cache/torch/hub/baudm_parseq_main
+# Recognized word: house
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Using cache found in /DATA1/ocrteam/.cache/torch/hub/baudm_parseq_main
+# Recognized word: rajiv
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Recognized word: राजीव
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Recognized word: चौक
+
+
+```
+
+
+
+## Acknowledgement
+
+Text Recognition - [PARseq](https://github.com/baudm/parseq)\
+EAST re-implementation [repository](https://github.com/foamliu/EAST).
+National Language Translation Mission [Bhashini](https://bhashini.gov.in/). +## Contact us +For any queries, please contact us at: +- [Anik De](mailto:anekde@gmail.com) + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..41e70524aa99dd57856e852a9bab1b0cbda3aac9 --- /dev/null +++ b/app.py @@ -0,0 +1,118 @@ +import gradio as gr +from PIL import Image +import os +from IndicPhotoOCR.ocr import OCR # Ensure OCR class is saved in a file named ocr.py +from IndicPhotoOCR.theme import Seafoam + +# Initialize the OCR object for text detection and recognition +ocr = OCR(device="cpu", verbose=False) + +def process_image(image): + """ + Processes the uploaded image for text detection and recognition. + - Detects bounding boxes in the image + - Draws bounding boxes on the image and identifies script in each detected area + - Recognizes text in each cropped region and returns the annotated image and recognized text + + Parameters: + image (PIL.Image): The input image to be processed. + + Returns: + tuple: A PIL.Image with bounding boxes and a string of recognized text. + """ + + # Save the input image temporarily + image_path = "input_image.jpg" + image.save(image_path) + + # Detect bounding boxes on the image using OCR + detections = ocr.detect(image_path) + + # Draw bounding boxes on the image and save it as output + ocr.visualize_detection(image_path, detections, save_path="output_image.png") + + # Load the annotated image with bounding boxes drawn + output_image = Image.open("output_image.png") + + # Initialize list to hold recognized text from each detected area + recognized_texts = [] + pil_image = Image.open(image_path) + + # Process each detected bounding box for script identification and text recognition + for bbox in detections: + # Identify the script and crop the image to this region + script_lang, cropped_path = ocr.crop_and_identify_script(pil_image, bbox) + + if script_lang: # Only proceed if a script language is identified + # Recognize text in the cropped area + recognized_text = ocr.recognise(cropped_path, script_lang) + recognized_texts.append(recognized_text) + + # Combine recognized texts into a single string for display + recognized_texts_combined = " ".join(recognized_texts) + return output_image, recognized_texts_combined + +# Custom HTML for interface header with logos and alignment +interface_html = """ +
+
+ IITJ Logo +
+ Bhashini Logo +
+""" + + + +# Links to GitHub and Dataset repositories with GitHub icon +links_html = """ + +""" + +# Custom CSS to style the text box font size +custom_css = """ +.custom-textbox textarea { + font-size: 20px !important; +} +""" + +# Create an instance of the Seafoam theme for a consistent visual style +seafoam = Seafoam() + +# Define examples for users to try out +examples = [ + ["test_images/image_141.jpg"], + ["test_images/image_1164.jpg"] +] + +title = "

Developed by IITJ

" + +# Set up the Gradio Interface with the defined function and customizations +demo = gr.Interface( + fn=process_image, + inputs=gr.Image(type="pil", image_mode="RGB"), + outputs=[ + gr.Image(type="pil", label="Detected Bounding Boxes"), + gr.Textbox(label="Recognized Text", elem_classes="custom-textbox") + ], + title="IndicPhotoOCR - Indic Scene Text Recogniser Toolkit", + description=title+interface_html+links_html, + theme=seafoam, + css=custom_css, + examples=examples +) + +# # Server setup and launch configuration +if __name__ == "__main__": + server = "0.0.0.0" # IP address for server + port = 7865 # Port to run the server on + demo.launch(server_name=server, server_port=port) + +# demo.launch() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..f075eec24e314deb35a57251dddde0816a7fcb03 --- /dev/null +++ b/setup.py @@ -0,0 +1,79 @@ +from setuptools import setup, find_packages + +setup( + name="IndicPhotoOCR", + version="1.1.0", + description="Scene Text Recognition Toolkit across 13 Indian Languages which contains detection, script identification, and text recognition modules", + long_description=open("README.md").read() + "\n\n" + open("CHANGELOG.md").read(), + long_description_content_type="text/markdown", + author="Anik De", + author_email="anekde@gmail.com", + url="https://github.com/Bhashini-IITJ/IndicPhotoOCR", + packages=find_packages(), + python_requires='>=3.9', + install_requires=[ + # Your mandatory dependencies here + 'aiohappyeyeballs==2.4.3', + 'aiohttp==3.10.10', + 'aiosignal==1.3.1', + 'async-timeout==4.0.3', + 'attrs==24.2.0', + 'certifi==2024.8.30', + 'charset-normalizer==3.4.0', + 'click==8.1.7', + 'filelock==3.16.1', + 'frozenlist==1.5.0', + 'fsspec==2024.10.0', + 'huggingface-hub==0.26.1', + 'idna==3.10', + 'jinja2==3.1.4', + 'joblib==1.4.2', + 'lightning-utilities==0.11.8', + 'markupsafe==3.0.2', + 'mpmath==1.3.0', + 'multidict==6.1.0', + 'networkx==3.2.1', + 'nltk==3.9.1', + 'numpy==1.26.4', + 'packaging==24.1', + 'pillow==11.0.0', + 'propcache==0.2.0', + 'pytorch-lightning==2.4.0', + 'pyyaml==6.0.2', + 'regex==2024.9.11', + 'requests==2.32.3', + 'safetensors==0.4.5', + 'sympy==1.13.1', + 'timm==1.0.11', + 'torchmetrics==1.5.1', + 'tqdm==4.66.5', + 'typing-extensions==4.12.2', + 'urllib3==2.2.3', + 'yarl==1.16.0', + 'opencv-python==4.10.0.84', + 'shapely==2.0.6', + 'openai-clip==1.0.1', + 'lmdb==1.5.1' + + ], + extras_require={ + 'cu118': [ + 'torch==2.5.0+cu118', + 'torchvision==0.20.0+cu118', + # Any additional packages specific to cu118 + ], + 'cu121': [ + 'torch==2.5.0+cu121', + 'torchvision==0.20.0+cu121', + # Any additional packages specific to cu121 + ], + 'cpu': [ + 'torch==2.5.0', + 'torchvision==0.20.0', + # Any additional packages specific to CPU + ], + 'extra': [ + 'six==1.16.0', # Your other extra requirements + ], + }, +)