hantech committed on
Commit
bd22b5e
1 Parent(s): 747c9ad

Upload 38 files

Files changed (37)
  1. vietocr/.gitignore +140 -0
  2. vietocr/config.yml +69 -0
  3. vietocr/config/__init__.py +0 -0
  4. vietocr/config/base.yml +58 -0
  5. vietocr/config/resnet-transformer.yml +16 -0
  6. vietocr/config/resnet_fpn_transformer.yml +9 -0
  7. vietocr/config/vgg-convseq2seq.yml +40 -0
  8. vietocr/config/vgg-seq2seq.yml +37 -0
  9. vietocr/config/vgg-transformer.yml +26 -0
  10. vietocr/vietocr/__init__.py +0 -0
  11. vietocr/vietocr/loader/__init__.py +0 -0
  12. vietocr/vietocr/loader/aug.py +48 -0
  13. vietocr/vietocr/loader/dataloader.py +205 -0
  14. vietocr/vietocr/loader/dataloader_v1.py +155 -0
  15. vietocr/vietocr/model/__init__.py +0 -0
  16. vietocr/vietocr/model/backbone/__init__.py +0 -0
  17. vietocr/vietocr/model/backbone/cnn.py +28 -0
  18. vietocr/vietocr/model/backbone/resnet.py +140 -0
  19. vietocr/vietocr/model/backbone/vgg.py +50 -0
  20. vietocr/vietocr/model/beam.py +104 -0
  21. vietocr/vietocr/model/seqmodel/__init__.py +0 -0
  22. vietocr/vietocr/model/seqmodel/convseq2seq.py +324 -0
  23. vietocr/vietocr/model/seqmodel/seq2seq.py +175 -0
  24. vietocr/vietocr/model/seqmodel/transformer.py +124 -0
  25. vietocr/vietocr/model/trainer.py +363 -0
  26. vietocr/vietocr/model/transformerocr.py +45 -0
  27. vietocr/vietocr/model/vocab.py +36 -0
  28. vietocr/vietocr/optim/__init__.py +0 -0
  29. vietocr/vietocr/optim/labelsmoothingloss.py +25 -0
  30. vietocr/vietocr/optim/optim.py +58 -0
  31. vietocr/vietocr/tool/__init__.py +0 -0
  32. vietocr/vietocr/tool/config.py +40 -0
  33. vietocr/vietocr/tool/create_dataset.py +105 -0
  34. vietocr/vietocr/tool/logger.py +17 -0
  35. vietocr/vietocr/tool/predictor.py +85 -0
  36. vietocr/vietocr/tool/translate.py +172 -0
  37. vietocr/vietocr/tool/utils.py +93 -0
vietocr/.gitignore ADDED
@@ -0,0 +1,140 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ my_train_vietocr.py/
vietocr/config.yml ADDED
@@ -0,0 +1,69 @@
+ aug:
+   image_aug: true
+   masked_language_model: true
+ backbone: vgg19_bn
+ cnn:
+   hidden: 256
+   ks:
+   - - 2
+     - 2
+   - - 2
+     - 2
+   - - 2
+     - 1
+   - - 2
+     - 1
+   - - 1
+     - 1
+   pretrained: true
+   ss:
+   - - 2
+     - 2
+   - - 2
+     - 2
+   - - 2
+     - 1
+   - - 2
+     - 1
+   - - 1
+     - 1
+ dataloader:
+   num_workers: 0
+   pin_memory: true
+ dataset:
+   data_root: /Users/bmd1905/Desktop/dataset/data_line
+   image_height: 32
+   image_max_width: 512
+   image_min_width: 32
+   name: hw
+   train_annotation: train_line_annotation.txt
+   valid_annotation: test_line_annotation.txt
+ device: mps
+ optimizer:
+   max_lr: 0.0003
+   pct_start: 0.1
+ predictor:
+   beamsearch: false
+ pretrain: https://vocr.vn/data/vietocr/vgg_transformer.pth
+ quiet: false
+ seq_modeling: transformer
+ trainer:
+   batch_size: 1
+   checkpoint: ./checkpoint/transformerocr_checkpoint.pth
+   export: ./weights/seq2seq_test_local.pth
+   iters: 100
+   log: ./train.log
+   metrics: 100
+   print_every: 1
+   valid_every: 10
+ transformer:
+   d_model: 256
+   dim_feedforward: 2048
+   max_seq_length: 1024
+   nhead: 8
+   num_decoder_layers: 6
+   num_encoder_layers: 6
+   pos_dropout: 0.1
+   trans_dropout: 0.1
+ vocab: 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ '
+ weights: https://vocr.vn/data/vietocr/vgg_transformer.pth
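Note (editorial, not part of the commit): a minimal sketch of how a config like the one above could be loaded and overridden with plain PyYAML. The repository's own Cfg helper in vietocr/vietocr/tool/config.py presumably wraps something similar, but that wrapper is not shown here, so the snippet below is an assumption-laden illustration.

# Hypothetical usage sketch -- assumes PyYAML is installed and config.yml sits under ./vietocr/.
import yaml

with open('vietocr/config.yml', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Override a couple of fields before training, e.g. run on CPU with a larger batch.
config['device'] = 'cpu'
config['trainer']['batch_size'] = 8

print(config['backbone'], config['seq_modeling'], config['trainer']['iters'])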
vietocr/config/__init__.py ADDED
File without changes
vietocr/config/base.yml ADDED
@@ -0,0 +1,58 @@
+ # change to the list of chars in your dataset, or use the default Vietnamese chars
+ vocab: 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ '
+
+ # cpu, cuda, cuda:0
+ device: cuda:0
+
+ seq_modeling: transformer
+ transformer:
+   d_model: 256
+   nhead: 8
+   num_encoder_layers: 6
+   num_decoder_layers: 6
+   dim_feedforward: 2048
+   max_seq_length: 1024
+   pos_dropout: 0.1
+   trans_dropout: 0.1
+
+ optimizer:
+   max_lr: 0.0003
+   pct_start: 0.1
+
+ trainer:
+   batch_size: 32
+   print_every: 200
+   valid_every: 4000
+   iters: 100000
+   # where to save our model for prediction
+   export: ./weights/transformerocr.pth
+   checkpoint: ./checkpoint/transformerocr_checkpoint.pth
+   log: ./train.log
+   # null to disable accuracy computation, or set to the number of samples to enable validation while training
+   metrics: null
+
+ dataset:
+   # name of your dataset
+   name: data
+   # path to annotations and images
+   data_root: ./img/
+   train_annotation: annotation_train.txt
+   valid_annotation: annotation_val_small.txt
+   # resize images to height 32; a larger height will increase accuracy
+   image_height: 32
+   image_min_width: 32
+   image_max_width: 512
+
+ dataloader:
+   num_workers: 3
+   pin_memory: True
+
+ aug:
+   image_aug: true
+   masked_language_model: true
+
+ predictor:
+   # enable or disable beam search during prediction; beam search is slower
+   beamsearch: False
+
+ quiet: False
vietocr/config/resnet-transformer.yml ADDED
@@ -0,0 +1,16 @@
+ pretrain:
+   id_or_url: 13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
+   md5: 7068030afe2e8fc639d0e1e2c25612b3
+   cached: /tmp/tranformerorc.pth
+
+ weights: https://drive.google.com/uc?id=12dTOZ9VP7ZVzwQgVvqBWz5JO5RXXW5NY
+
+ backbone: resnet50
+ cnn:
+   ss:
+     - [2, 2]
+     - [2, 1]
+     - [2, 1]
+     - [2, 1]
+     - [1, 1]
+   hidden: 256
vietocr/config/resnet_fpn_transformer.yml ADDED
@@ -0,0 +1,9 @@
+ pretrain:
+   id_or_url: 13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
+   md5: 7068030afe2e8fc639d0e1e2c25612b3
+   cached: /tmp/tranformerorc.pth
+
+ weights: https://drive.google.com/uc?id=12dTOZ9VP7ZVzwQgVvqBWz5JO5RXXW5NY
+
+ backbone: resnet50_fpn
+ cnn: {}
vietocr/config/vgg-convseq2seq.yml ADDED
@@ -0,0 +1,40 @@
+ pretrain:
+   id_or_url: 13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
+   md5: fbefa85079ad9001a71eb1bf47a93785
+   cached: /tmp/tranformerorc.pth
+
+ # url or local path
+ weights: https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
+
+ backbone: vgg19_bn
+ cnn:
+   # pooling stride size
+   ss:
+     - [2, 2]
+     - [2, 2]
+     - [2, 1]
+     - [2, 1]
+     - [1, 1]
+   # pooling kernel size
+   ks:
+     - [2, 2]
+     - [2, 2]
+     - [2, 1]
+     - [2, 1]
+     - [1, 1]
+   # dim of output feature map
+   hidden: 256
+
+ seq_modeling: convseq2seq
+ transformer:
+   emb_dim: 256
+   hid_dim: 512
+   enc_layers: 10
+   dec_layers: 10
+   enc_kernel_size: 3
+   dec_kernel_size: 3
+   dropout: 0.1
+   pad_idx: 0
+   device: cuda:1
+   enc_max_length: 512
+   dec_max_length: 512
vietocr/config/vgg-seq2seq.yml ADDED
@@ -0,0 +1,37 @@
+ # for training
+ pretrain: https://vocr.vn/data/vietocr/vgg_seq2seq.pth
+
+ # url or local path (for prediction)
+ weights: https://vocr.vn/data/vietocr/vgg_seq2seq.pth
+
+ backbone: vgg19_bn
+ cnn:
+   # pooling stride size
+   ss:
+     - [2, 2]
+     - [2, 2]
+     - [2, 1]
+     - [2, 1]
+     - [1, 1]
+   # pooling kernel size
+   ks:
+     - [2, 2]
+     - [2, 2]
+     - [2, 1]
+     - [2, 1]
+     - [1, 1]
+   # dim of output feature map
+   hidden: 256
+
+ seq_modeling: seq2seq
+
+ transformer:
+   encoder_hidden: 256
+   decoder_hidden: 256
+   img_channel: 256
+   decoder_embedded: 256
+   dropout: 0.1
+
+ optimizer:
+   max_lr: 0.001
+   pct_start: 0.1
vietocr/config/vgg-transformer.yml ADDED
@@ -0,0 +1,26 @@
+ # for training
+ pretrain: https://vocr.vn/data/vietocr/vgg_transformer.pth
+
+ # url or local path (for prediction)
+ weights: https://vocr.vn/data/vietocr/vgg_transformer.pth
+
+ backbone: vgg19_bn
+ cnn:
+   pretrained: True
+   # pooling stride size
+   ss:
+     - [2, 2]
+     - [2, 2]
+     - [2, 1]
+     - [2, 1]
+     - [1, 1]
+   # pooling kernel size
+   ks:
+     - [2, 2]
+     - [2, 2]
+     - [2, 1]
+     - [2, 1]
+     - [1, 1]
+   # dim of output feature map
+   hidden: 256
+
vietocr/vietocr/__init__.py ADDED
File without changes
vietocr/vietocr/loader/__init__.py ADDED
File without changes
vietocr/vietocr/loader/aug.py ADDED
@@ -0,0 +1,48 @@
+ from PIL import Image
+ import numpy as np
+
+ from imgaug import augmenters as iaa
+ import imgaug as ia
+
+ class ImgAugTransform:
+     def __init__(self):
+         sometimes = lambda aug: iaa.Sometimes(0.3, aug)
+
+         self.aug = iaa.Sequential(iaa.SomeOf((1, 5),
+             [
+             # blur
+
+             sometimes(iaa.OneOf([iaa.GaussianBlur(sigma=(0, 1.0)),
+                                  iaa.MotionBlur(k=3)])),
+
+             # color
+             sometimes(iaa.AddToHueAndSaturation(value=(-10, 10), per_channel=True)),
+             sometimes(iaa.SigmoidContrast(gain=(3, 10), cutoff=(0.4, 0.6), per_channel=True)),
+             sometimes(iaa.Invert(0.25, per_channel=0.5)),
+             sometimes(iaa.Solarize(0.5, threshold=(32, 128))),
+             sometimes(iaa.Dropout2d(p=0.5)),
+             sometimes(iaa.Multiply((0.5, 1.5), per_channel=0.5)),
+             sometimes(iaa.Add((-40, 40), per_channel=0.5)),
+
+             sometimes(iaa.JpegCompression(compression=(5, 80))),
+
+             # distort
+             sometimes(iaa.Crop(percent=(0.01, 0.05), sample_independently=True)),
+             sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.01))),
+             sometimes(iaa.Affine(scale=(0.7, 1.3), translate_percent=(-0.1, 0.1),
+                                  # rotate=(-5, 5), shear=(-5, 5),
+                                  order=[0, 1], cval=(0, 255),
+                                  mode=ia.ALL)),
+             sometimes(iaa.PiecewiseAffine(scale=(0.01, 0.01))),
+             sometimes(iaa.OneOf([iaa.Dropout(p=(0, 0.1)),
+                                  iaa.CoarseDropout(p=(0, 0.1), size_percent=(0.02, 0.25))])),
+
+             ],
+             random_order=True),
+             random_order=True)
+
+     def __call__(self, img):
+         img = np.array(img)
+         img = self.aug.augment_image(img)
+         img = Image.fromarray(img)
+         return img
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import random
4
+ from PIL import Image
5
+ from PIL import ImageFile
6
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
7
+
8
+ from collections import defaultdict
9
+ import numpy as np
10
+ import torch
11
+ import lmdb
12
+ import six
13
+ import time
14
+ from tqdm import tqdm
15
+
16
+ from torch.utils.data import Dataset
17
+ from torch.utils.data.sampler import Sampler
18
+ from vietocr.vietocr.tool.translate import process_image
19
+ from vietocr.vietocr.tool.create_dataset import createDataset
20
+ from vietocr.vietocr.tool.translate import resize
21
+
22
+ class OCRDataset(Dataset):
23
+ def __init__(self, lmdb_path, root_dir, annotation_path, vocab, image_height=32, image_min_width=32, image_max_width=512, transform=None):
24
+ self.root_dir = root_dir
25
+ self.annotation_path = os.path.join(root_dir, annotation_path)
26
+ self.vocab = vocab
27
+ self.transform = transform
28
+
29
+ self.image_height = image_height
30
+ self.image_min_width = image_min_width
31
+ self.image_max_width = image_max_width
32
+
33
+ self.lmdb_path = lmdb_path
34
+
35
+ if os.path.isdir(self.lmdb_path):
36
+ print('{} exists. Remove folder if you want to create new dataset'.format(self.lmdb_path))
37
+ sys.stdout.flush()
38
+ else:
39
+ createDataset(self.lmdb_path, root_dir, annotation_path)
40
+
41
+ self.env = lmdb.open(
42
+ self.lmdb_path,
43
+ max_readers=8,
44
+ readonly=True,
45
+ lock=False,
46
+ readahead=False,
47
+ meminit=False)
48
+ self.txn = self.env.begin(write=False)
49
+
50
+ nSamples = int(self.txn.get('num-samples'.encode()))
51
+ self.nSamples = nSamples
52
+
53
+ self.build_cluster_indices()
54
+
55
+ def build_cluster_indices(self):
56
+ self.cluster_indices = defaultdict(list)
57
+
58
+ pbar = tqdm(range(self.__len__()),
59
+ desc='{} build cluster'.format(self.lmdb_path),
60
+ ncols = 100, position=0, leave=True)
61
+
62
+ for i in pbar:
63
+ bucket = self.get_bucket(i)
64
+ self.cluster_indices[bucket].append(i)
65
+
66
+
67
+ def get_bucket(self, idx):
68
+ key = 'dim-%09d'%idx
69
+
70
+ dim_img = self.txn.get(key.encode())
71
+ dim_img = np.fromstring(dim_img, dtype=np.int32)
72
+ imgH, imgW = dim_img
73
+
74
+ new_w, image_height = resize(imgW, imgH, self.image_height, self.image_min_width, self.image_max_width)
75
+
76
+ return new_w
77
+
78
+ def read_buffer(self, idx):
79
+ img_file = 'image-%09d'%idx
80
+ label_file = 'label-%09d'%idx
81
+ path_file = 'path-%09d'%idx
82
+
83
+ imgbuf = self.txn.get(img_file.encode())
84
+
85
+ label = self.txn.get(label_file.encode()).decode()
86
+ img_path = self.txn.get(path_file.encode()).decode()
87
+
88
+ buf = six.BytesIO()
89
+ buf.write(imgbuf)
90
+ buf.seek(0)
91
+
92
+ return buf, label, img_path
93
+
94
+ def read_data(self, idx):
95
+ buf, label, img_path = self.read_buffer(idx)
96
+
97
+ img = Image.open(buf).convert('RGB')
98
+
99
+ if self.transform:
100
+ img = self.transform(img)
101
+
102
+ img_bw = process_image(img, self.image_height, self.image_min_width, self.image_max_width)
103
+
104
+ word = self.vocab.encode(label)
105
+
106
+ return img_bw, word, img_path
107
+
108
+ def __getitem__(self, idx):
109
+ img, word, img_path = self.read_data(idx)
110
+
111
+ img_path = os.path.join(self.root_dir, img_path)
112
+
113
+ sample = {'img': img, 'word': word, 'img_path': img_path}
114
+
115
+ return sample
116
+
117
+ def __len__(self):
118
+ return self.nSamples
119
+
120
+ class ClusterRandomSampler(Sampler):
121
+
122
+ def __init__(self, data_source, batch_size, shuffle=True):
123
+ self.data_source = data_source
124
+ self.batch_size = batch_size
125
+ self.shuffle = shuffle
126
+
127
+ def flatten_list(self, lst):
128
+ return [item for sublist in lst for item in sublist]
129
+
130
+ def __iter__(self):
131
+ batch_lists = []
132
+ for cluster, cluster_indices in self.data_source.cluster_indices.items():
133
+ if self.shuffle:
134
+ random.shuffle(cluster_indices)
135
+
136
+ batches = [cluster_indices[i:i + self.batch_size] for i in range(0, len(cluster_indices), self.batch_size)]
137
+ batches = [_ for _ in batches if len(_) == self.batch_size]
138
+ if self.shuffle:
139
+ random.shuffle(batches)
140
+
141
+ batch_lists.append(batches)
142
+
143
+ lst = self.flatten_list(batch_lists)
144
+ if self.shuffle:
145
+ random.shuffle(lst)
146
+
147
+ lst = self.flatten_list(lst)
148
+
149
+ return iter(lst)
150
+
151
+ def __len__(self):
152
+ return len(self.data_source)
153
+
154
+ class Collator(object):
155
+ def __init__(self, masked_language_model=True):
156
+ self.masked_language_model = masked_language_model
157
+
158
+ def __call__(self, batch):
159
+ filenames = []
160
+ img = []
161
+ target_weights = []
162
+ tgt_input = []
163
+ max_label_len = max(len(sample['word']) for sample in batch)
164
+ for sample in batch:
165
+ img.append(sample['img'])
166
+ filenames.append(sample['img_path'])
167
+ label = sample['word']
168
+ label_len = len(label)
169
+
170
+
171
+ tgt = np.concatenate((
172
+ label,
173
+ np.zeros(max_label_len - label_len, dtype=np.int32)))
174
+ tgt_input.append(tgt)
175
+
176
+ one_mask_len = label_len - 1
177
+
178
+ target_weights.append(np.concatenate((
179
+ np.ones(one_mask_len, dtype=np.float32),
180
+ np.zeros(max_label_len - one_mask_len,dtype=np.float32))))
181
+
182
+ img = np.array(img, dtype=np.float32)
183
+
184
+
185
+ tgt_input = np.array(tgt_input, dtype=np.int64).T
186
+ tgt_output = np.roll(tgt_input, -1, 0).T
187
+ tgt_output[:, -1]=0
188
+
189
+ # random mask token
190
+ if self.masked_language_model:
191
+ mask = np.random.random(size=tgt_input.shape) < 0.05
192
+ mask = mask & (tgt_input != 0) & (tgt_input != 1) & (tgt_input != 2)
193
+ tgt_input[mask] = 3
194
+
195
+ tgt_padding_mask = np.array(target_weights)==0
196
+
197
+ rs = {
198
+ 'img': torch.FloatTensor(img),
199
+ 'tgt_input': torch.LongTensor(tgt_input),
200
+ 'tgt_output': torch.LongTensor(tgt_output),
201
+ 'tgt_padding_mask': torch.BoolTensor(tgt_padding_mask),
202
+ 'filenames': filenames
203
+ }
204
+
205
+ return rs
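Note (editorial, not part of the commit): a hedged sketch of how OCRDataset, ClusterRandomSampler and Collator above might be wired into a torch DataLoader (trainer.py is expected to do roughly this). The Vocab constructor argument, the character set, and the LMDB/annotation paths are assumptions for illustration only.

from torch.utils.data import DataLoader

from vietocr.vietocr.model.vocab import Vocab
from vietocr.vietocr.loader.dataloader import OCRDataset, ClusterRandomSampler, Collator

vocab = Vocab('abcdefghijklmnopqrstuvwxyz0123456789 ')  # assumed: Vocab takes the character set string

dataset = OCRDataset(
    lmdb_path='./train_lmdb',            # LMDB is (re)built here from the annotation file if missing
    root_dir='./img',
    annotation_path='annotation_train.txt',
    vocab=vocab,
    image_height=32, image_min_width=32, image_max_width=512,
)
sampler = ClusterRandomSampler(dataset, batch_size=32, shuffle=True)
loader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler,                     # keeps every batch inside one image-width bucket
    collate_fn=Collator(masked_language_model=True),
    num_workers=0,
    pin_memory=True,
)

batch = next(iter(loader))
print(batch['img'].shape, batch['tgt_input'].shape, batch['tgt_padding_mask'].shape)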
vietocr/vietocr/loader/dataloader_v1.py ADDED
@@ -0,0 +1,155 @@
+ import torch
+ import numpy as np
+ from PIL import Image
+ import random
+ from vietocr.vietocr.model.vocab import Vocab
+ from vietocr.vietocr.tool.translate import process_image
+ import os
+ from collections import defaultdict
+ import math
+ from prefetch_generator import background
+
+ class BucketData(object):
+     def __init__(self, device):
+         self.max_label_len = 0
+         self.data_list = []
+         self.label_list = []
+         self.file_list = []
+         self.device = device
+
+     def append(self, datum, label, filename):
+         self.data_list.append(datum)
+         self.label_list.append(label)
+         self.file_list.append(filename)
+
+         self.max_label_len = max(len(label), self.max_label_len)
+
+         return len(self.data_list)
+
+     def flush_out(self):
+         """
+         Shape:
+             - img: (N, C, H, W)
+             - tgt_input: (T, N)
+             - tgt_output: (N, T)
+             - tgt_padding_mask: (N, T)
+         """
+         # encoder part
+         img = np.array(self.data_list, dtype=np.float32)
+
+         # decoder part
+         target_weights = []
+         tgt_input = []
+         for label in self.label_list:
+             label_len = len(label)
+
+             tgt = np.concatenate((
+                 label,
+                 np.zeros(self.max_label_len - label_len, dtype=np.int32)))
+             tgt_input.append(tgt)
+
+             one_mask_len = label_len - 1
+
+             target_weights.append(np.concatenate((
+                 np.ones(one_mask_len, dtype=np.float32),
+                 np.zeros(self.max_label_len - one_mask_len,dtype=np.float32))))
+
+         # reshape to fit input shape
+         tgt_input = np.array(tgt_input, dtype=np.int64).T
+         tgt_output = np.roll(tgt_input, -1, 0).T
+         tgt_output[:, -1]=0
+
+         tgt_padding_mask = np.array(target_weights)==0
+
+         filenames = self.file_list
+
+         self.data_list, self.label_list, self.file_list = [], [], []
+         self.max_label_len = 0
+
+         rs = {
+             'img': torch.FloatTensor(img).to(self.device),
+             'tgt_input': torch.LongTensor(tgt_input).to(self.device),
+             'tgt_output': torch.LongTensor(tgt_output).to(self.device),
+             'tgt_padding_mask':torch.BoolTensor(tgt_padding_mask).to(self.device),
+             'filenames': filenames
+         }
+
+         return rs
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __iadd__(self, other):
+         self.data_list += other.data_list
+         self.label_list += other.label_list
+         self.max_label_len = max(self.max_label_len, other.max_label_len)
+         self.max_width = max(self.max_width, other.max_width)
+
+     def __add__(self, other):
+         res = BucketData()
+         res.data_list = self.data_list + other.data_list
+         res.label_list = self.label_list + other.label_list
+         res.max_width = max(self.max_width, other.max_width)
+         res.max_label_len = max((self.max_label_len, other.max_label_len))
+         return res
+
+ class DataGen(object):
+
+     def __init__(self,data_root, annotation_fn, vocab, device, image_height=32, image_min_width=32, image_max_width=512):
+
+         self.image_height = image_height
+         self.image_min_width = image_min_width
+         self.image_max_width = image_max_width
+
+         self.data_root = data_root
+         self.annotation_path = os.path.join(data_root, annotation_fn)
+
+         self.vocab = vocab
+         self.device = device
+
+         self.clear()
+
+     def clear(self):
+         self.bucket_data = defaultdict(lambda: BucketData(self.device))
+
+     @background(max_prefetch=1)
+     def gen(self, batch_size, last_batch=True):
+         with open(self.annotation_path, 'r') as ann_file:
+             lines = ann_file.readlines()
+             np.random.shuffle(lines)
+             for l in lines:
+
+                 img_path, lex = l.strip().split('\t')
+
+                 img_path = os.path.join(self.data_root, img_path)
+
+                 try:
+                     img_bw, word = self.read_data(img_path, lex)
+                 except IOError:
+                     print('ioread image:{}'.format(img_path))
+
+                 width = img_bw.shape[-1]
+
+                 bs = self.bucket_data[width].append(img_bw, word, img_path)
+                 if bs >= batch_size:
+                     b = self.bucket_data[width].flush_out()
+                     yield b
+
+             if last_batch:
+                 for bucket in self.bucket_data.values():
+                     if len(bucket) > 0:
+                         b = bucket.flush_out()
+                         yield b
+
+             self.clear()
+
+     def read_data(self, img_path, lex):
+
+         with open(img_path, 'rb') as img_file:
+             img = Image.open(img_file).convert('RGB')
+             img_bw = process_image(img, self.image_height, self.image_min_width, self.image_max_width)
+
+         word = self.vocab.encode(lex)
+
+         return img_bw, word
+
vietocr/vietocr/model/__init__.py ADDED
File without changes
vietocr/vietocr/model/backbone/__init__.py ADDED
File without changes
vietocr/vietocr/model/backbone/cnn.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+ from torch import nn
+
+ import vietocr.vietocr.model.backbone.vgg as vgg
+ from vietocr.vietocr.model.backbone.resnet import Resnet50
+
+ class CNN(nn.Module):
+     def __init__(self, backbone, **kwargs):
+         super(CNN, self).__init__()
+
+         if backbone == 'vgg11_bn':
+             self.model = vgg.vgg11_bn(**kwargs)
+         elif backbone == 'vgg19_bn':
+             self.model = vgg.vgg19_bn(**kwargs)
+         elif backbone == 'resnet50':
+             self.model = Resnet50(**kwargs)
+
+     def forward(self, x):
+         return self.model(x)
+
+     def freeze(self):
+         for name, param in self.model.features.named_parameters():
+             if name != 'last_conv_1x1':
+                 param.requires_grad = False
+
+     def unfreeze(self):
+         for param in self.model.features.parameters():
+             param.requires_grad = True
vietocr/vietocr/model/backbone/resnet.py ADDED
@@ -0,0 +1,140 @@
+ import torch
+ from torch import nn
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(BasicBlock, self).__init__()
+         self.conv1 = self._conv3x3(inplanes, planes)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = self._conv3x3(planes, planes)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def _conv3x3(self, in_planes, out_planes, stride=1):
+         "3x3 convolution with padding"
+         return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                          padding=1, bias=False)
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+ class ResNet(nn.Module):
+
+     def __init__(self, input_channel, output_channel, block, layers):
+         super(ResNet, self).__init__()
+
+         self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
+
+         self.inplanes = int(output_channel / 8)
+         self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
+                                  kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
+         self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
+                                  kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn0_2 = nn.BatchNorm2d(self.inplanes)
+         self.relu = nn.ReLU(inplace=True)
+
+         self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+         self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
+         self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
+             0], kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
+
+         self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+         self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
+         self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
+             1], kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
+
+         self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
+         self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
+         self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
+             2], kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
+
+         self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
+         self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
+             3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
+         self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
+         self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
+             3], kernel_size=2, stride=1, padding=0, bias=False)
+         self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.conv0_1(x)
+         x = self.bn0_1(x)
+         x = self.relu(x)
+         x = self.conv0_2(x)
+         x = self.bn0_2(x)
+         x = self.relu(x)
+
+         x = self.maxpool1(x)
+         x = self.layer1(x)
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+
+         x = self.maxpool2(x)
+         x = self.layer2(x)
+         x = self.conv2(x)
+         x = self.bn2(x)
+         x = self.relu(x)
+
+         x = self.maxpool3(x)
+         x = self.layer3(x)
+         x = self.conv3(x)
+         x = self.bn3(x)
+         x = self.relu(x)
+
+         x = self.layer4(x)
+         x = self.conv4_1(x)
+         x = self.bn4_1(x)
+         x = self.relu(x)
+         x = self.conv4_2(x)
+         x = self.bn4_2(x)
+         conv = self.relu(x)
+
+         conv = conv.transpose(-1, -2)
+         conv = conv.flatten(2)
+         conv = conv.permute(-1, 0, 1)
+
+         return conv
+
+ def Resnet50(ss, hidden):
+     return ResNet(3, hidden, BasicBlock, [1, 2, 5, 3])
+
vietocr/vietocr/model/backbone/vgg.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ from torch import nn
+ from torchvision import models
+ from einops import rearrange
+ from torchvision.models._utils import IntermediateLayerGetter
+
+
+ class Vgg(nn.Module):
+     def __init__(self, name, ss, ks, hidden, pretrained=True, dropout=0.5):
+         super(Vgg, self).__init__()
+
+         if name == 'vgg11_bn':
+             cnn = models.vgg11_bn(pretrained=pretrained)
+         elif name == 'vgg19_bn':
+             cnn = models.vgg19_bn(pretrained=pretrained)
+
+         pool_idx = 0
+
+         for i, layer in enumerate(cnn.features):
+             if isinstance(layer, torch.nn.MaxPool2d):
+                 cnn.features[i] = torch.nn.AvgPool2d(kernel_size=ks[pool_idx], stride=ss[pool_idx], padding=0)
+                 pool_idx += 1
+
+         self.features = cnn.features
+         self.dropout = nn.Dropout(dropout)
+         self.last_conv_1x1 = nn.Conv2d(512, hidden, 1)
+
+     def forward(self, x):
+         """
+         Shape:
+             - x: (N, C, H, W)
+             - output: (W, N, C)
+         """
+
+         conv = self.features(x)
+         conv = self.dropout(conv)
+         conv = self.last_conv_1x1(conv)
+
+         # conv = rearrange(conv, 'b d h w -> b d (w h)')
+         conv = conv.transpose(-1, -2)
+         conv = conv.flatten(2)
+         conv = conv.permute(-1, 0, 1)
+         return conv
+
+ def vgg11_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
+     return Vgg('vgg11_bn', ss, ks, hidden, pretrained, dropout)
+
+ def vgg19_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
+     return Vgg('vgg19_bn', ss, ks, hidden, pretrained, dropout)
+
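Note (editorial, not part of the commit): a quick shape check for the VGG backbone above, run through the CNN wrapper from cnn.py, assuming torchvision is installed. pretrained=False avoids downloading ImageNet weights; ss/ks mirror vgg-transformer.yml, and the printed shape is what the pooling arithmetic suggests, not a guaranteed value.

import torch
from vietocr.vietocr.model.backbone.cnn import CNN

ss = [[2, 2], [2, 2], [2, 1], [2, 1], [1, 1]]   # pooling strides: height halved 4x, width halved 2x
ks = [[2, 2], [2, 2], [2, 1], [2, 1], [1, 1]]   # pooling kernel sizes
cnn = CNN('vgg19_bn', ss=ss, ks=ks, hidden=256, pretrained=False)

x = torch.rand(2, 3, 32, 160)    # (N, C, H, W): two 32px-high line images
feat = cnn(x)
print(feat.shape)                # a (seq_len, N, hidden) feature sequence, e.g. torch.Size([80, 2, 256])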
vietocr/vietocr/model/beam.py ADDED
@@ -0,0 +1,104 @@
+ import torch
+
+ class Beam:
+
+     def __init__(self, beam_size=8, min_length=0, n_top=1, ranker=None,
+                  start_token_id=1, end_token_id=2):
+         self.beam_size = beam_size
+         self.min_length = min_length
+         self.ranker = ranker
+
+         self.end_token_id = end_token_id
+         self.top_sentence_ended = False
+
+         self.prev_ks = []
+         self.next_ys = [torch.LongTensor(beam_size).fill_(start_token_id)] # remove padding
+
+         self.current_scores = torch.FloatTensor(beam_size).zero_()
+         self.all_scores = []
+
+         # Time and k pair for finished.
+         self.finished = []
+         self.n_top = n_top
+
+         self.ranker = ranker
+
+     def advance(self, next_log_probs):
+         # next_probs : beam_size X vocab_size
+
+         vocabulary_size = next_log_probs.size(1)
+         # current_beam_size = next_log_probs.size(0)
+
+         current_length = len(self.next_ys)
+         if current_length < self.min_length:
+             for beam_index in range(len(next_log_probs)):
+                 next_log_probs[beam_index][self.end_token_id] = -1e10
+
+         if len(self.prev_ks) > 0:
+             beam_scores = next_log_probs + self.current_scores.unsqueeze(1).expand_as(next_log_probs)
+             # Don't let EOS have children.
+             last_y = self.next_ys[-1]
+             for beam_index in range(last_y.size(0)):
+                 if last_y[beam_index] == self.end_token_id:
+                     beam_scores[beam_index] = -1e10 # -1e20 raises error when executing
+         else:
+             beam_scores = next_log_probs[0]
+
+         flat_beam_scores = beam_scores.view(-1)
+         top_scores, top_score_ids = flat_beam_scores.topk(k=self.beam_size, dim=0, largest=True, sorted=True)
+
+         self.current_scores = top_scores
+         self.all_scores.append(self.current_scores)
+
+         prev_k = top_score_ids // vocabulary_size # (beam_size, )
+         next_y = top_score_ids - prev_k * vocabulary_size # (beam_size, )
+
+
+         self.prev_ks.append(prev_k)
+         self.next_ys.append(next_y)
+
+         for beam_index, last_token_id in enumerate(next_y):
+
+             if last_token_id == self.end_token_id:
+
+                 # skip scoring
+                 self.finished.append((self.current_scores[beam_index], len(self.next_ys) - 1, beam_index))
+
+         if next_y[0] == self.end_token_id:
+             self.top_sentence_ended = True
+
+     def get_current_state(self):
+         "Get the outputs for the current timestep."
+         return torch.stack(self.next_ys, dim=1)
+
+     def get_current_origin(self):
+         "Get the backpointers for the current timestep."
+         return self.prev_ks[-1]
+
+     def done(self):
+         return self.top_sentence_ended and len(self.finished) >= self.n_top
+
+     def get_hypothesis(self, timestep, k):
+         hypothesis = []
+         for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
+             hypothesis.append(self.next_ys[j + 1][k])
+             # for RNN, [:, k, :], and for transformer, [k, :, :]
+             k = self.prev_ks[j][k]
+
+         return hypothesis[::-1]
+
+     def sort_finished(self, minimum=None):
+         if minimum is not None:
+             i = 0
+             # Add from beam until we have minimum outputs.
+             while len(self.finished) < minimum:
+                 # global_scores = self.global_scorer.score(self, self.scores)
+                 # s = global_scores[i]
+                 s = self.current_scores[i]
+                 self.finished.append((s, len(self.next_ys) - 1, i))
+                 i += 1
+
+         self.finished = sorted(self.finished, key=lambda a: a[0], reverse=True)
+         scores = [sc for sc, _, _ in self.finished]
+         ks = [(t, k) for _, t, k in self.finished]
+         return scores, ks
File without changes
vietocr/vietocr/model/seqmodel/convseq2seq.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torch.nn.functional as F
5
+
6
+ class Encoder(nn.Module):
7
+ def __init__(self,
8
+ emb_dim,
9
+ hid_dim,
10
+ n_layers,
11
+ kernel_size,
12
+ dropout,
13
+ device,
14
+ max_length = 512):
15
+ super().__init__()
16
+
17
+ assert kernel_size % 2 == 1, "Kernel size must be odd!"
18
+
19
+ self.device = device
20
+
21
+ self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)
22
+
23
+ # self.tok_embedding = nn.Embedding(input_dim, emb_dim)
24
+ self.pos_embedding = nn.Embedding(max_length, emb_dim)
25
+
26
+ self.emb2hid = nn.Linear(emb_dim, hid_dim)
27
+ self.hid2emb = nn.Linear(hid_dim, emb_dim)
28
+
29
+ self.convs = nn.ModuleList([nn.Conv1d(in_channels = hid_dim,
30
+ out_channels = 2 * hid_dim,
31
+ kernel_size = kernel_size,
32
+ padding = (kernel_size - 1) // 2)
33
+ for _ in range(n_layers)])
34
+
35
+ self.dropout = nn.Dropout(dropout)
36
+
37
+ def forward(self, src):
38
+
39
+ #src = [batch size, src len]
40
+
41
+ src = src.transpose(0, 1)
42
+
43
+ batch_size = src.shape[0]
44
+ src_len = src.shape[1]
45
+ device = src.device
46
+
47
+ #create position tensor
48
+ pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(device)
49
+
50
+ #pos = [0, 1, 2, 3, ..., src len - 1]
51
+
52
+ #pos = [batch size, src len]
53
+
54
+ #embed tokens and positions
55
+
56
+ # tok_embedded = self.tok_embedding(src)
57
+ tok_embedded = src
58
+
59
+ pos_embedded = self.pos_embedding(pos)
60
+
61
+ #tok_embedded = pos_embedded = [batch size, src len, emb dim]
62
+
63
+ #combine embeddings by elementwise summing
64
+ embedded = self.dropout(tok_embedded + pos_embedded)
65
+
66
+ #embedded = [batch size, src len, emb dim]
67
+
68
+ #pass embedded through linear layer to convert from emb dim to hid dim
69
+ conv_input = self.emb2hid(embedded)
70
+
71
+ #conv_input = [batch size, src len, hid dim]
72
+
73
+ #permute for convolutional layer
74
+ conv_input = conv_input.permute(0, 2, 1)
75
+
76
+ #conv_input = [batch size, hid dim, src len]
77
+
78
+ #begin convolutional blocks...
79
+
80
+ for i, conv in enumerate(self.convs):
81
+
82
+ #pass through convolutional layer
83
+ conved = conv(self.dropout(conv_input))
84
+
85
+ #conved = [batch size, 2 * hid dim, src len]
86
+
87
+ #pass through GLU activation function
88
+ conved = F.glu(conved, dim = 1)
89
+
90
+ #conved = [batch size, hid dim, src len]
91
+
92
+ #apply residual connection
93
+ conved = (conved + conv_input) * self.scale
94
+
95
+ #conved = [batch size, hid dim, src len]
96
+
97
+ #set conv_input to conved for next loop iteration
98
+ conv_input = conved
99
+
100
+ #...end convolutional blocks
101
+
102
+ #permute and convert back to emb dim
103
+ conved = self.hid2emb(conved.permute(0, 2, 1))
104
+
105
+ #conved = [batch size, src len, emb dim]
106
+
107
+ #elementwise sum output (conved) and input (embedded) to be used for attention
108
+ combined = (conved + embedded) * self.scale
109
+
110
+ #combined = [batch size, src len, emb dim]
111
+
112
+ return conved, combined
113
+
114
+ class Decoder(nn.Module):
115
+ def __init__(self,
116
+ output_dim,
117
+ emb_dim,
118
+ hid_dim,
119
+ n_layers,
120
+ kernel_size,
121
+ dropout,
122
+ trg_pad_idx,
123
+ device,
124
+ max_length = 512):
125
+ super().__init__()
126
+
127
+ self.kernel_size = kernel_size
128
+ self.trg_pad_idx = trg_pad_idx
129
+ self.device = device
130
+
131
+ self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)
132
+
133
+ self.tok_embedding = nn.Embedding(output_dim, emb_dim)
134
+ self.pos_embedding = nn.Embedding(max_length, emb_dim)
135
+
136
+ self.emb2hid = nn.Linear(emb_dim, hid_dim)
137
+ self.hid2emb = nn.Linear(hid_dim, emb_dim)
138
+
139
+ self.attn_hid2emb = nn.Linear(hid_dim, emb_dim)
140
+ self.attn_emb2hid = nn.Linear(emb_dim, hid_dim)
141
+
142
+ self.fc_out = nn.Linear(emb_dim, output_dim)
143
+
144
+ self.convs = nn.ModuleList([nn.Conv1d(in_channels = hid_dim,
145
+ out_channels = 2 * hid_dim,
146
+ kernel_size = kernel_size)
147
+ for _ in range(n_layers)])
148
+
149
+ self.dropout = nn.Dropout(dropout)
150
+
151
+ def calculate_attention(self, embedded, conved, encoder_conved, encoder_combined):
152
+
153
+ #embedded = [batch size, trg len, emb dim]
154
+ #conved = [batch size, hid dim, trg len]
155
+ #encoder_conved = encoder_combined = [batch size, src len, emb dim]
156
+
157
+ #permute and convert back to emb dim
158
+ conved_emb = self.attn_hid2emb(conved.permute(0, 2, 1))
159
+
160
+ #conved_emb = [batch size, trg len, emb dim]
161
+
162
+ combined = (conved_emb + embedded) * self.scale
163
+
164
+ #combined = [batch size, trg len, emb dim]
165
+
166
+ energy = torch.matmul(combined, encoder_conved.permute(0, 2, 1))
167
+
168
+ #energy = [batch size, trg len, src len]
169
+
170
+ attention = F.softmax(energy, dim=2)
171
+
172
+ #attention = [batch size, trg len, src len]
173
+
174
+ attended_encoding = torch.matmul(attention, encoder_combined)
175
+
176
+ #attended_encoding = [batch size, trg len, emd dim]
177
+
178
+ #convert from emb dim -> hid dim
179
+ attended_encoding = self.attn_emb2hid(attended_encoding)
180
+
181
+ #attended_encoding = [batch size, trg len, hid dim]
182
+
183
+ #apply residual connection
184
+ attended_combined = (conved + attended_encoding.permute(0, 2, 1)) * self.scale
185
+
186
+ #attended_combined = [batch size, hid dim, trg len]
187
+
188
+ return attention, attended_combined
189
+
190
+ def forward(self, trg, encoder_conved, encoder_combined):
191
+
192
+ #trg = [batch size, trg len]
193
+ #encoder_conved = encoder_combined = [batch size, src len, emb dim]
194
+ trg = trg.transpose(0, 1)
195
+
196
+ batch_size = trg.shape[0]
197
+ trg_len = trg.shape[1]
198
+ device = trg.device
199
+
200
+ #create position tensor
201
+ pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(device)
202
+
203
+ #pos = [batch size, trg len]
204
+
205
+ #embed tokens and positions
206
+ tok_embedded = self.tok_embedding(trg)
207
+ pos_embedded = self.pos_embedding(pos)
208
+
209
+ #tok_embedded = [batch size, trg len, emb dim]
210
+ #pos_embedded = [batch size, trg len, emb dim]
211
+
212
+ #combine embeddings by elementwise summing
213
+ embedded = self.dropout(tok_embedded + pos_embedded)
214
+
215
+ #embedded = [batch size, trg len, emb dim]
216
+
217
+ #pass embedded through linear layer to go through emb dim -> hid dim
218
+ conv_input = self.emb2hid(embedded)
219
+
220
+ #conv_input = [batch size, trg len, hid dim]
221
+
222
+ #permute for convolutional layer
223
+ conv_input = conv_input.permute(0, 2, 1)
224
+
225
+ #conv_input = [batch size, hid dim, trg len]
226
+
227
+ batch_size = conv_input.shape[0]
228
+ hid_dim = conv_input.shape[1]
229
+
230
+ for i, conv in enumerate(self.convs):
231
+
232
+ #apply dropout
233
+ conv_input = self.dropout(conv_input)
234
+
235
+ #need to pad so decoder can't "cheat"
236
+ padding = torch.zeros(batch_size,
237
+ hid_dim,
238
+ self.kernel_size - 1).fill_(self.trg_pad_idx).to(device)
239
+
240
+ padded_conv_input = torch.cat((padding, conv_input), dim = 2)
241
+
242
+ #padded_conv_input = [batch size, hid dim, trg len + kernel size - 1]
243
+
244
+ #pass through convolutional layer
245
+ conved = conv(padded_conv_input)
246
+
247
+ #conved = [batch size, 2 * hid dim, trg len]
248
+
249
+ #pass through GLU activation function
250
+ conved = F.glu(conved, dim = 1)
251
+
252
+ #conved = [batch size, hid dim, trg len]
253
+
254
+ #calculate attention
255
+ attention, conved = self.calculate_attention(embedded,
256
+ conved,
257
+ encoder_conved,
258
+ encoder_combined)
259
+
260
+ #attention = [batch size, trg len, src len]
261
+
262
+ #apply residual connection
263
+ conved = (conved + conv_input) * self.scale
264
+
265
+ #conved = [batch size, hid dim, trg len]
266
+
267
+ #set conv_input to conved for next loop iteration
268
+ conv_input = conved
269
+
270
+ conved = self.hid2emb(conved.permute(0, 2, 1))
271
+
272
+ #conved = [batch size, trg len, emb dim]
273
+
274
+ output = self.fc_out(self.dropout(conved))
275
+
276
+ #output = [batch size, trg len, output dim]
277
+
278
+ return output, attention
279
+
280
+ class ConvSeq2Seq(nn.Module):
281
+ def __init__(self, vocab_size, emb_dim, hid_dim, enc_layers, dec_layers, enc_kernel_size, dec_kernel_size, enc_max_length, dec_max_length, dropout, pad_idx, device):
282
+ super().__init__()
283
+
284
+ enc = Encoder(emb_dim, hid_dim, enc_layers, enc_kernel_size, dropout, device, enc_max_length)
285
+ dec = Decoder(vocab_size, emb_dim, hid_dim, dec_layers, dec_kernel_size, dropout, pad_idx, device, dec_max_length)
286
+
287
+ self.encoder = enc
288
+ self.decoder = dec
289
+
290
+ def forward_encoder(self, src):
291
+ encoder_conved, encoder_combined = self.encoder(src)
292
+
293
+ return encoder_conved, encoder_combined
294
+
295
+ def forward_decoder(self, trg, memory):
296
+ encoder_conved, encoder_combined = memory
297
+ output, attention = self.decoder(trg, encoder_conved, encoder_combined)
298
+
299
+ return output, (encoder_conved, encoder_combined)
300
+
301
+ def forward(self, src, trg):
302
+
303
+ #src = [batch size, src len]
304
+ #trg = [batch size, trg len - 1] (<eos> token sliced off the end)
305
+
306
+ #calculate z^u (encoder_conved) and (z^u + e) (encoder_combined)
307
+ #encoder_conved is output from final encoder conv. block
308
+ #encoder_combined is encoder_conved plus (elementwise) src embedding plus
309
+ # positional embeddings
310
+ encoder_conved, encoder_combined = self.encoder(src)
311
+
312
+ #encoder_conved = [batch size, src len, emb dim]
313
+ #encoder_combined = [batch size, src len, emb dim]
314
+
315
+ #calculate predictions of next words
316
+ #output is a batch of predictions for each word in the trg sentence
317
+ #attention a batch of attention scores across the src sentence for
318
+ # each word in the trg sentence
319
+ output, attention = self.decoder(trg, encoder_conved, encoder_combined)
320
+
321
+ #output = [batch size, trg len - 1, output dim]
322
+ #attention = [batch size, trg len - 1, src len]
323
+
324
+ return output#, attention
vietocr/vietocr/model/seqmodel/seq2seq.py ADDED
@@ -0,0 +1,175 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import torch.nn.functional as F
+
+ class Encoder(nn.Module):
+     def __init__(self, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
+         super().__init__()
+
+         self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
+         self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, src):
+         """
+         src: src_len x batch_size x img_channel
+         outputs: src_len x batch_size x hid_dim
+         hidden: batch_size x hid_dim
+         """
+
+         embedded = self.dropout(src)
+
+         outputs, hidden = self.rnn(embedded)
+
+         hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
+
+         return outputs, hidden
+
+ class Attention(nn.Module):
+     def __init__(self, enc_hid_dim, dec_hid_dim):
+         super().__init__()
+
+         self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
+         self.v = nn.Linear(dec_hid_dim, 1, bias = False)
+
+     def forward(self, hidden, encoder_outputs):
+         """
+         hidden: batch_size x hid_dim
+         encoder_outputs: src_len x batch_size x hid_dim,
+         outputs: batch_size x src_len
+         """
+
+         batch_size = encoder_outputs.shape[1]
+         src_len = encoder_outputs.shape[0]
+
+         hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
+
+         encoder_outputs = encoder_outputs.permute(1, 0, 2)
+
+         energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2)))
+
+         attention = self.v(energy).squeeze(2)
+
+         return F.softmax(attention, dim = 1)
+
+ class Decoder(nn.Module):
+     def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
+         super().__init__()
+
+         self.output_dim = output_dim
+         self.attention = attention
+
+         self.embedding = nn.Embedding(output_dim, emb_dim)
+         self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
+         self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, input, hidden, encoder_outputs):
+         """
+         inputs: batch_size
+         hidden: batch_size x hid_dim
+         encoder_outputs: src_len x batch_size x hid_dim
+         """
+
+         input = input.unsqueeze(0)
+
+         embedded = self.dropout(self.embedding(input))
+
+         a = self.attention(hidden, encoder_outputs)
+
+         a = a.unsqueeze(1)
+
+         encoder_outputs = encoder_outputs.permute(1, 0, 2)
+
+         weighted = torch.bmm(a, encoder_outputs)
+
+         weighted = weighted.permute(1, 0, 2)
+
+         rnn_input = torch.cat((embedded, weighted), dim = 2)
+
+         output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
+
+         assert (output == hidden).all()
+
+         embedded = embedded.squeeze(0)
+         output = output.squeeze(0)
+         weighted = weighted.squeeze(0)
+
+         prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))
+
+         return prediction, hidden.squeeze(0), a.squeeze(1)
+
+ class Seq2Seq(nn.Module):
+     def __init__(self, vocab_size, encoder_hidden, decoder_hidden, img_channel, decoder_embedded, dropout=0.1):
+         super().__init__()
+
+         attn = Attention(encoder_hidden, decoder_hidden)
+
+         self.encoder = Encoder(img_channel, encoder_hidden, decoder_hidden, dropout)
+         self.decoder = Decoder(vocab_size, decoder_embedded, encoder_hidden, decoder_hidden, dropout, attn)
+
+     def forward_encoder(self, src):
+         """
+         src: timestep x batch_size x channel
+         hidden: batch_size x hid_dim
+         encoder_outputs: src_len x batch_size x hid_dim
+         """
+
+         encoder_outputs, hidden = self.encoder(src)
+
+         return (hidden, encoder_outputs)
+
+     def forward_decoder(self, tgt, memory):
+         """
+         tgt: timestep x batch_size
+         hidden: batch_size x hid_dim
+         encoder_outputs: src_len x batch_size x hid_dim
+         output: batch_size x 1 x vocab_size
+         """
+
+         tgt = tgt[-1]
+         hidden, encoder_outputs = memory
+         output, hidden, _ = self.decoder(tgt, hidden, encoder_outputs)
+         output = output.unsqueeze(1)
+
+         return output, (hidden, encoder_outputs)
+
+     def forward(self, src, trg):
+         """
+         src: time_step x batch_size
+         trg: time_step x batch_size
+         outputs: batch_size x time_step x vocab_size
+         """
+
+         batch_size = src.shape[1]
+         trg_len = trg.shape[0]
+         trg_vocab_size = self.decoder.output_dim
+         device = src.device
+
+         outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(device)
+         encoder_outputs, hidden = self.encoder(src)
+
+         for t in range(trg_len):
+             input = trg[t]
+             output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
+
+             outputs[t] = output
+
+         outputs = outputs.transpose(0, 1).contiguous()
+
+         return outputs
+
+     def expand_memory(self, memory, beam_size):
+         hidden, encoder_outputs = memory
+         hidden = hidden.repeat(beam_size, 1)
+         encoder_outputs = encoder_outputs.repeat(1, beam_size, 1)
+
+         return (hidden, encoder_outputs)
+
+     def get_memory(self, memory, i):
+         hidden, encoder_outputs = memory
+         hidden = hidden[[i]]
+         encoder_outputs = encoder_outputs[:, [i],:]
+
+         return (hidden, encoder_outputs)
vietocr/vietocr/model/seqmodel/transformer.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from einops import rearrange
2
+ from torchvision import models
3
+ import math
4
+ import torch
5
+ from torch import nn
6
+
7
+ class LanguageTransformer(nn.Module):
8
+ def __init__(self, vocab_size,
9
+ d_model, nhead,
10
+ num_encoder_layers, num_decoder_layers,
11
+ dim_feedforward, max_seq_length,
12
+ pos_dropout, trans_dropout):
13
+ super().__init__()
14
+
15
+ self.d_model = d_model
16
+ self.embed_tgt = nn.Embedding(vocab_size, d_model)
17
+ self.pos_enc = PositionalEncoding(d_model, pos_dropout, max_seq_length)
18
+ # self.learned_pos_enc = LearnedPositionalEncoding(d_model, pos_dropout, max_seq_length)
19
+
20
+ self.transformer = nn.Transformer(d_model, nhead,
21
+ num_encoder_layers, num_decoder_layers,
22
+ dim_feedforward, trans_dropout)
23
+
24
+ self.fc = nn.Linear(d_model, vocab_size)
25
+
26
+ def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
27
+ """
28
+ Shape:
29
+ - src: (W, N, C)
30
+ - tgt: (T, N)
31
+ - src_key_padding_mask: (N, S)
32
+ - tgt_key_padding_mask: (N, T)
33
+ - memory_key_padding_mask: (N, S)
34
+ - output: (N, T, E)
35
+
36
+ """
37
+ tgt_mask = self.gen_nopeek_mask(tgt.shape[0]).to(src.device)
38
+
39
+ src = self.pos_enc(src*math.sqrt(self.d_model))
40
+ # src = self.learned_pos_enc(src*math.sqrt(self.d_model))
41
+
42
+ tgt = self.pos_enc(self.embed_tgt(tgt) * math.sqrt(self.d_model))
43
+
44
+ output = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask,
45
+ tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
46
+ # output = rearrange(output, 't n e -> n t e')
47
+ output = output.transpose(0, 1)
48
+ return self.fc(output)
49
+
50
+ def gen_nopeek_mask(self, length):
51
+ mask = (torch.triu(torch.ones(length, length)) == 1).transpose(0, 1)
52
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
53
+
54
+ return mask
55
+
56
+ def forward_encoder(self, src):
57
+ src = self.pos_enc(src*math.sqrt(self.d_model))
58
+ memory = self.transformer.encoder(src)
59
+ return memory
60
+
61
+ def forward_decoder(self, tgt, memory):
62
+ tgt_mask = self.gen_nopeek_mask(tgt.shape[0]).to(tgt.device)
63
+ tgt = self.pos_enc(self.embed_tgt(tgt) * math.sqrt(self.d_model))
64
+
65
+ output = self.transformer.decoder(tgt, memory, tgt_mask=tgt_mask)
66
+ # output = rearrange(output, 't n e -> n t e')
67
+ output = output.transpose(0, 1)
68
+
69
+ return self.fc(output), memory
70
+
71
+ def expand_memory(self, memory, beam_size):
72
+ memory = memory.repeat(1, beam_size, 1)
73
+ return memory
74
+
75
+ def get_memory(self, memory, i):
76
+ memory = memory[:, [i], :]
77
+ return memory
78
+
79
+ class PositionalEncoding(nn.Module):
80
+ def __init__(self, d_model, dropout=0.1, max_len=100):
81
+ super(PositionalEncoding, self).__init__()
82
+ self.dropout = nn.Dropout(p=dropout)
83
+
84
+ pe = torch.zeros(max_len, d_model)
85
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
86
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
87
+ pe[:, 0::2] = torch.sin(position * div_term)
88
+ pe[:, 1::2] = torch.cos(position * div_term)
89
+ pe = pe.unsqueeze(0).transpose(0, 1)
90
+ self.register_buffer('pe', pe)
91
+
92
+ def forward(self, x):
93
+ x = x + self.pe[:x.size(0), :]
94
+
95
+ return self.dropout(x)
96
+
97
+ class LearnedPositionalEncoding(nn.Module):
98
+ def __init__(self, d_model, dropout=0.1, max_len=100):
99
+ super(LearnedPositionalEncoding, self).__init__()
100
+ self.dropout = nn.Dropout(p=dropout)
101
+
102
+ self.pos_embed = nn.Embedding(max_len, d_model)
103
+ self.layernorm = LayerNorm(d_model)
104
+
105
+ def forward(self, x):
106
+ seq_len = x.size(0)
107
+ pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
108
+ pos = pos.unsqueeze(-1).expand(x.size()[:2])
109
+ x = x + self.pos_embed(pos)
110
+ return self.dropout(self.layernorm(x))
111
+
112
+ class LayerNorm(nn.Module):
113
+ "A layernorm module in the TF style (epsilon inside the square root)."
114
+ def __init__(self, d_model, variance_epsilon=1e-12):
115
+ super().__init__()
116
+ self.gamma = nn.Parameter(torch.ones(d_model))
117
+ self.beta = nn.Parameter(torch.zeros(d_model))
118
+ self.variance_epsilon = variance_epsilon
119
+
120
+ def forward(self, x):
121
+ u = x.mean(-1, keepdim=True)
122
+ s = (x - u).pow(2).mean(-1, keepdim=True)
123
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
124
+ return self.gamma * x + self.beta
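
A quick shape check for LanguageTransformer, as a sketch: the hyperparameters below (vocab_size=233, d_model=256, etc.) are illustrative and not taken from the project's configs; only the tensor layout follows the forward docstring above.

import torch
from vietocr.vietocr.model.seqmodel.transformer import LanguageTransformer

model = LanguageTransformer(vocab_size=233, d_model=256, nhead=8,
                            num_encoder_layers=6, num_decoder_layers=6,
                            dim_feedforward=2048, max_seq_length=1024,
                            pos_dropout=0.1, trans_dropout=0.1)

src = torch.rand(170, 2, 256)            # (W, N, C) feature sequence from the CNN backbone
tgt = torch.randint(0, 233, (32, 2))     # (T, N) target token ids
out = model(src, tgt)
print(out.shape)                         # torch.Size([2, 32, 233]) -> (N, T, vocab_size)
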
vietocr/vietocr/model/trainer.py ADDED
@@ -0,0 +1,363 @@
1
+ from vietocr.vietocr.optim.optim import ScheduledOptim
2
+ from vietocr.vietocr.optim.labelsmoothingloss import LabelSmoothingLoss
3
+ from torch.optim import Adam, SGD, AdamW
4
+ from torch import nn
5
+ from vietocr.vietocr.tool.translate import build_model
6
+ from vietocr.vietocr.tool.translate import translate, batch_translate_beam_search
7
+ from vietocr.vietocr.tool.utils import download_weights
8
+ from vietocr.vietocr.tool.logger import Logger
9
+ from vietocr.vietocr.loader.aug import ImgAugTransform
10
+
11
+ import yaml
12
+ import torch
13
+ from vietocr.vietocr.loader.dataloader_v1 import DataGen
14
+ from vietocr.vietocr.loader.dataloader import OCRDataset, ClusterRandomSampler, Collator
15
+ from torch.utils.data import DataLoader
16
+ from einops import rearrange
17
+ from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR, OneCycleLR
18
+
19
+ import torchvision
20
+
21
+ from vietocr.vietocr.tool.utils import compute_accuracy
22
+ from PIL import Image
23
+ import numpy as np
24
+ import os
25
+ import matplotlib.pyplot as plt
26
+ import time
27
+
28
+ class Trainer():
29
+ def __init__(self, config, pretrained=True, augmentor=ImgAugTransform()):
30
+
31
+ self.config = config
32
+ self.model, self.vocab = build_model(config)
33
+
34
+ self.device = config['device']
35
+ self.num_iters = config['trainer']['iters']
36
+ self.beamsearch = config['predictor']['beamsearch']
37
+
38
+ self.data_root = config['dataset']['data_root']
39
+ self.train_annotation = config['dataset']['train_annotation']
40
+ self.valid_annotation = config['dataset']['valid_annotation']
41
+ self.dataset_name = config['dataset']['name']
42
+
43
+ self.batch_size = config['trainer']['batch_size']
44
+ self.print_every = config['trainer']['print_every']
45
+ self.valid_every = config['trainer']['valid_every']
46
+
47
+ self.image_aug = config['aug']['image_aug']
48
+ self.masked_language_model = config['aug']['masked_language_model']
49
+
50
+ self.checkpoint = config['trainer']['checkpoint']
51
+ self.export_weights = config['trainer']['export']
52
+ self.metrics = config['trainer']['metrics']
53
+ logger = config['trainer']['log']
54
+
55
+ if logger:
56
+ self.logger = Logger(logger)
57
+
58
+ if pretrained:
59
+ weight_file = download_weights(config['pretrain'], quiet=config['quiet'])
60
+ self.load_weights(weight_file)
61
+
62
+ self.iter = 0
63
+
64
+ self.optimizer = AdamW(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09)
65
+ self.scheduler = OneCycleLR(self.optimizer, total_steps=self.num_iters, **config['optimizer'])
66
+ # self.optimizer = ScheduledOptim(
67
+ # Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
68
+ # #config['transformer']['d_model'],
69
+ # 512,
70
+ # **config['optimizer'])
71
+
72
+ self.criterion = LabelSmoothingLoss(len(self.vocab), padding_idx=self.vocab.pad, smoothing=0.1)
73
+
74
+ transforms = None
75
+ if self.image_aug:
76
+ transforms = augmentor
77
+
78
+ self.train_gen = self.data_gen('train_{}'.format(self.dataset_name),
79
+ self.data_root, self.train_annotation, self.masked_language_model, transform=transforms)
80
+ if self.valid_annotation:
81
+ self.valid_gen = self.data_gen('valid_{}'.format(self.dataset_name),
82
+ self.data_root, self.valid_annotation, masked_language_model=False)
83
+
84
+ self.train_losses = []
85
+
86
+ def train(self):
87
+ total_loss = 0
88
+
89
+ total_loader_time = 0
90
+ total_gpu_time = 0
91
+ best_acc = 0
92
+
93
+ data_iter = iter(self.train_gen)
94
+ for i in range(self.num_iters):
95
+ self.iter += 1
96
+
97
+ start = time.time()
98
+
99
+ try:
100
+ batch = next(data_iter)
101
+ except StopIteration:
102
+ data_iter = iter(self.train_gen)
103
+ batch = next(data_iter)
104
+
105
+ total_loader_time += time.time() - start
106
+
107
+ start = time.time()
108
+ loss = self.step(batch)
109
+ total_gpu_time += time.time() - start
110
+
111
+ total_loss += loss
112
+ self.train_losses.append((self.iter, loss))
113
+
114
+ if self.iter % self.print_every == 0:
115
+ info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(self.iter,
116
+ total_loss/self.print_every, self.optimizer.param_groups[0]['lr'],
117
+ total_loader_time, total_gpu_time)
118
+
119
+ total_loss = 0
120
+ total_loader_time = 0
121
+ total_gpu_time = 0
122
+ print(info)
123
+ self.logger.log(info)
124
+
125
+ if self.valid_annotation and self.iter % self.valid_every == 0:
126
+ val_loss = self.validate()
127
+ acc_full_seq, acc_per_char = self.precision(self.metrics)
128
+
129
+ info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f}'.format(self.iter, val_loss, acc_full_seq, acc_per_char)
130
+ print(info)
131
+ self.logger.log(info)
132
+
133
+ if acc_full_seq > best_acc:
134
+ self.save_weights(self.export_weights)
135
+ best_acc = acc_full_seq
136
+
137
+
138
+ def validate(self):
139
+ self.model.eval()
140
+
141
+ total_loss = []
142
+
143
+ with torch.no_grad():
144
+ for step, batch in enumerate(self.valid_gen):
145
+ batch = self.batch_to_device(batch)
146
+ img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch['tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']
147
+
148
+ outputs = self.model(img, tgt_input, tgt_padding_mask)
149
+ # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
150
+
151
+ outputs = outputs.flatten(0,1)
152
+ tgt_output = tgt_output.flatten()
153
+ loss = self.criterion(outputs, tgt_output)
154
+
155
+ total_loss.append(loss.item())
156
+
157
+ del outputs
158
+ del loss
159
+
160
+ total_loss = np.mean(total_loss)
161
+ self.model.train()
162
+
163
+ return total_loss
164
+
165
+ def predict(self, sample=None):
166
+ pred_sents = []
167
+ actual_sents = []
168
+ img_files = []
169
+
170
+ for batch in self.valid_gen:
171
+ batch = self.batch_to_device(batch)
172
+
173
+ if self.beamsearch:
174
+ translated_sentence = batch_translate_beam_search(batch['img'], self.model)
175
+ prob = None
176
+ else:
177
+ translated_sentence, prob = translate(batch['img'], self.model)
178
+
179
+ pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
180
+ actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())
181
+
182
+ img_files.extend(batch['filenames'])
183
+
184
+ pred_sents.extend(pred_sent)
185
+ actual_sents.extend(actual_sent)
186
+
187
+ if sample is not None and len(pred_sents) > sample:
188
+ break
189
+
190
+ return pred_sents, actual_sents, img_files, prob
191
+
192
+ def precision(self, sample=None):
193
+
194
+ pred_sents, actual_sents, _, _ = self.predict(sample=sample)
195
+
196
+ acc_full_seq = compute_accuracy(actual_sents, pred_sents, mode='full_sequence')
197
+ acc_per_char = compute_accuracy(actual_sents, pred_sents, mode='per_char')
198
+
199
+ return acc_full_seq, acc_per_char
200
+
201
+ def visualize_prediction(self, sample=16, errorcase=False, fontname='serif', fontsize=16):
202
+
203
+ pred_sents, actual_sents, img_files, probs = self.predict(sample)
204
+
205
+ if errorcase:
206
+ wrongs = []
207
+ for i in range(len(img_files)):
208
+ if pred_sents[i]!= actual_sents[i]:
209
+ wrongs.append(i)
210
+
211
+ pred_sents = [pred_sents[i] for i in wrongs]
212
+ actual_sents = [actual_sents[i] for i in wrongs]
213
+ img_files = [img_files[i] for i in wrongs]
214
+ probs = [probs[i] for i in wrongs]
215
+
216
+ img_files = img_files[:sample]
217
+
218
+ fontdict = {
219
+ 'family':fontname,
220
+ 'size':fontsize
221
+ }
222
+
223
+ for vis_idx in range(0, len(img_files)):
224
+ img_path = img_files[vis_idx]
225
+ pred_sent = pred_sents[vis_idx]
226
+ actual_sent = actual_sents[vis_idx]
227
+ prob = probs[vis_idx]
228
+
229
+ img = Image.open(open(img_path, 'rb'))
230
+ plt.figure()
231
+ plt.imshow(img)
232
+ plt.title('prob: {:.3f} - pred: {} - actual: {}'.format(prob, pred_sent, actual_sent), loc='left', fontdict=fontdict)
233
+ plt.axis('off')
234
+
235
+ plt.show()
236
+
237
+ def visualize_dataset(self, sample=16, fontname='serif'):
238
+ n = 0
239
+ for batch in self.train_gen:
240
+ for i in range(self.batch_size):
241
+ img = batch['img'][i].numpy().transpose(1,2,0)
242
+ sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())
243
+
244
+ plt.figure()
245
+ plt.title('sent: {}'.format(sent), loc='center', fontname=fontname)
246
+ plt.imshow(img)
247
+ plt.axis('off')
248
+
249
+ n += 1
250
+ if n >= sample:
251
+ plt.show()
252
+ return
253
+
254
+
255
+ def load_checkpoint(self, filename):
256
+ checkpoint = torch.load(filename)
257
+
258
+ optim = ScheduledOptim(
259
+ Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
260
+ self.config['transformer']['d_model'], **self.config['optimizer'])
261
+
262
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
263
+ self.model.load_state_dict(checkpoint['state_dict'])
264
+ self.iter = checkpoint['iter']
265
+
266
+ self.train_losses = checkpoint['train_losses']
267
+
268
+ def save_checkpoint(self, filename):
269
+ state = {'iter':self.iter, 'state_dict': self.model.state_dict(),
270
+ 'optimizer': self.optimizer.state_dict(), 'train_losses': self.train_losses}
271
+
272
+ path, _ = os.path.split(filename)
273
+ os.makedirs(path, exist_ok=True)
274
+
275
+ torch.save(state, filename)
276
+
277
+ def load_weights(self, filename):
278
+ state_dict = torch.load(filename, map_location=torch.device(self.device))
279
+
280
+ for name, param in self.model.named_parameters():
281
+ if name not in state_dict:
282
+ print('{} not found'.format(name))
283
+ elif state_dict[name].shape != param.shape:
284
+ print('{} mismatching shape, expected {} but found {}'.format(name, param.shape, state_dict[name].shape))
285
+ del state_dict[name]
286
+
287
+ self.model.load_state_dict(state_dict, strict=False)
288
+
289
+ def save_weights(self, filename):
290
+ path, _ = os.path.split(filename)
291
+ os.makedirs(path, exist_ok=True)
292
+
293
+ torch.save(self.model.state_dict(), filename)
294
+
295
+ def batch_to_device(self, batch):
296
+ img = batch['img'].to(self.device, non_blocking=True)
297
+ tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
298
+ tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
299
+ tgt_padding_mask = batch['tgt_padding_mask'].to(self.device, non_blocking=True)
300
+
301
+ batch = {
302
+ 'img': img, 'tgt_input':tgt_input,
303
+ 'tgt_output':tgt_output, 'tgt_padding_mask':tgt_padding_mask,
304
+ 'filenames': batch['filenames']
305
+ }
306
+
307
+ return batch
308
+
309
+ def data_gen(self, lmdb_path, data_root, annotation, masked_language_model=True, transform=None):
310
+ dataset = OCRDataset(lmdb_path=lmdb_path,
311
+ root_dir=data_root, annotation_path=annotation,
312
+ vocab=self.vocab, transform=transform,
313
+ image_height=self.config['dataset']['image_height'],
314
+ image_min_width=self.config['dataset']['image_min_width'],
315
+ image_max_width=self.config['dataset']['image_max_width'])
316
+
317
+ sampler = ClusterRandomSampler(dataset, self.batch_size, True)
318
+ collate_fn = Collator(masked_language_model)
319
+
320
+ gen = DataLoader(
321
+ dataset,
322
+ batch_size=self.batch_size,
323
+ sampler=sampler,
324
+ collate_fn = collate_fn,
325
+ shuffle=False,
326
+ drop_last=False,
327
+ **self.config['dataloader'])
328
+
329
+ return gen
330
+
331
+ def data_gen_v1(self, lmdb_path, data_root, annotation):
332
+ data_gen = DataGen(data_root, annotation, self.vocab, 'cpu',
333
+ image_height = self.config['dataset']['image_height'],
334
+ image_min_width = self.config['dataset']['image_min_width'],
335
+ image_max_width = self.config['dataset']['image_max_width'])
336
+
337
+ return data_gen
338
+
339
+ def step(self, batch):
340
+ self.model.train()
341
+
342
+ batch = self.batch_to_device(batch)
343
+ img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch['tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']
344
+
345
+ outputs = self.model(img, tgt_input, tgt_key_padding_mask=tgt_padding_mask)
346
+ # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
347
+ outputs = outputs.view(-1, outputs.size(2))#flatten(0, 1)
348
+ tgt_output = tgt_output.view(-1)#flatten()
349
+
350
+ loss = self.criterion(outputs, tgt_output)
351
+
352
+ self.optimizer.zero_grad()
353
+
354
+ loss.backward()
355
+
356
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
357
+
358
+ self.optimizer.step()
359
+ self.scheduler.step()
360
+
361
+ loss_item = loss.item()
362
+
363
+ return loss_item
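
A minimal training entry point, as a sketch: it assumes a config file (here vietocr/config.yml, path illustrative) that defines every key Trainer reads (device, vocab, backbone, cnn, transformer, seq_modeling, dataset, trainer, optimizer, dataloader, aug, predictor, plus pretrain/quiet when pretrained=True).

from vietocr.vietocr.tool.config import Cfg
from vietocr.vietocr.model.trainer import Trainer

config = Cfg.load_config_from_file('vietocr/config.yml')
trainer = Trainer(config, pretrained=False)        # skip downloading pretrained weights

trainer.visualize_dataset()                        # optional: eyeball a few augmented samples
trainer.train()                                    # runs for config['trainer']['iters'] iterations
print(trainer.precision(sample=200))               # (acc_full_seq, acc_per_char) on the validation set
trainer.save_checkpoint('./weights/checkpoint.pth')
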
vietocr/vietocr/model/transformerocr.py ADDED
@@ -0,0 +1,45 @@
1
+ from vietocr.vietocr.model.backbone.cnn import CNN
2
+ from vietocr.vietocr.model.seqmodel.transformer import LanguageTransformer
3
+ from vietocr.vietocr.model.seqmodel.seq2seq import Seq2Seq
4
+ from vietocr.vietocr.model.seqmodel.convseq2seq import ConvSeq2Seq
5
+ from torch import nn
6
+
7
+ class VietOCR(nn.Module):
8
+ def __init__(self, vocab_size,
9
+ backbone,
10
+ cnn_args,
11
+ transformer_args,
12
+ seq_modeling='transformer'):
13
+
14
+ super(VietOCR, self).__init__()
15
+
16
+ self.cnn = CNN(backbone, **cnn_args)
17
+ self.seq_modeling = seq_modeling
18
+
19
+ if seq_modeling == 'transformer':
20
+ self.transformer = LanguageTransformer(vocab_size, **transformer_args)
21
+ elif seq_modeling == 'seq2seq':
22
+ self.transformer = Seq2Seq(vocab_size, **transformer_args)
23
+ elif seq_modeling == 'convseq2seq':
24
+ self.transformer = ConvSeq2Seq(vocab_size, **transformer_args)
25
+ else:
26
+ raise ValueError('Unsupported seq_modeling: {}'.format(seq_modeling))
27
+
28
+ def forward(self, img, tgt_input, tgt_key_padding_mask):
29
+ """
30
+ Shape:
31
+ - img: (N, C, H, W)
32
+ - tgt_input: (T, N)
33
+ - tgt_key_padding_mask: (N, T)
34
+ - output: b t v
35
+ """
36
+ src = self.cnn(img)
37
+
38
+ if self.seq_modeling == 'transformer':
39
+ outputs = self.transformer(src, tgt_input, tgt_key_padding_mask=tgt_key_padding_mask)
40
+ elif self.seq_modeling == 'seq2seq':
41
+ outputs = self.transformer(src, tgt_input)
42
+ elif self.seq_modeling == 'convseq2seq':
43
+ outputs = self.transformer(src, tgt_input)
44
+ return outputs
45
+
vietocr/vietocr/model/vocab.py ADDED
@@ -0,0 +1,36 @@
1
+ class Vocab():
2
+ def __init__(self, chars):
3
+ self.pad = 0
4
+ self.go = 1
5
+ self.eos = 2
6
+ self.mask_token = 3
7
+
8
+ self.chars = chars
9
+
10
+ self.c2i = {c:i+4 for i, c in enumerate(chars)}
11
+
12
+ self.i2c = {i+4:c for i, c in enumerate(chars)}
13
+
14
+ self.i2c[0] = '<pad>'
15
+ self.i2c[1] = '<sos>'
16
+ self.i2c[2] = '<eos>'
17
+ self.i2c[3] = '*'
18
+
19
+ def encode(self, chars):
20
+ return [self.go] + [self.c2i[c] for c in chars] + [self.eos]
21
+
22
+ def decode(self, ids):
23
+ first = 1 if self.go in ids else 0
24
+ last = ids.index(self.eos) if self.eos in ids else None
25
+ sent = ''.join([self.i2c[i] for i in ids[first:last]])
26
+ return sent
27
+
28
+ def __len__(self):
29
+ return len(self.c2i) + 4
30
+
31
+ def batch_decode(self, arr):
32
+ texts = [self.decode(ids) for ids in arr]
33
+ return texts
34
+
35
+ def __str__(self):
36
+ return self.chars
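
A round-trip sketch for Vocab: indices 0 to 3 are reserved for <pad>, <sos>, <eos> and the mask token, so character ids start at 4. The character set below is illustrative.

from vietocr.vietocr.model.vocab import Vocab

vocab = Vocab('abcdefghijklmnopqrstuvwxyz')
ids = vocab.encode('hello')     # [1, 11, 8, 15, 15, 18, 2] -> <sos> h e l l o <eos>
print(vocab.decode(ids))        # 'hello'
print(len(vocab))               # 26 characters + 4 special tokens = 30
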
vietocr/vietocr/optim/__init__.py ADDED
File without changes
vietocr/vietocr/optim/labelsmoothingloss.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ class LabelSmoothingLoss(nn.Module):
5
+ def __init__(self, classes, padding_idx, smoothing=0.0, dim=-1):
6
+ super(LabelSmoothingLoss, self).__init__()
7
+ self.confidence = 1.0 - smoothing
8
+ self.smoothing = smoothing
9
+ self.cls = classes
10
+ self.dim = dim
11
+ self.padding_idx = padding_idx
12
+
13
+ def forward(self, pred, target):
14
+ pred = pred.log_softmax(dim=self.dim)
15
+ with torch.no_grad():
16
+ # true_dist = pred.data.clone()
17
+ true_dist = torch.zeros_like(pred)
18
+ true_dist.fill_(self.smoothing / (self.cls - 2))
19
+ true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
20
+ true_dist[:, self.padding_idx] = 0
21
+ mask = torch.nonzero(target.data == self.padding_idx, as_tuple=False)
22
+ if mask.dim() > 0:
23
+ true_dist.index_fill_(0, mask.squeeze(), 0.0)
24
+
25
+ return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
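
A quick sanity check of LabelSmoothingLoss on dummy logits (the vocabulary size and batch below are illustrative): it expects predictions flattened to (batch*time, vocab) and integer targets, with positions equal to padding_idx contributing nothing.

import torch
from vietocr.vietocr.optim.labelsmoothingloss import LabelSmoothingLoss

criterion = LabelSmoothingLoss(classes=10, padding_idx=0, smoothing=0.1)
logits = torch.randn(8, 10)               # (batch*time, vocab)
targets = torch.randint(1, 10, (8,))      # index 0 would mark padding
print(criterion(logits, targets).item())
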
vietocr/vietocr/optim/optim.py ADDED
@@ -0,0 +1,58 @@
1
+ class ScheduledOptim():
2
+ '''A simple wrapper class for learning rate scheduling'''
3
+
4
+ def __init__(self, optimizer, d_model, init_lr, n_warmup_steps):
5
+ assert n_warmup_steps > 0, 'must be greater than 0'
6
+
7
+ self._optimizer = optimizer
8
+ self.init_lr = init_lr
9
+ self.d_model = d_model
10
+ self.n_warmup_steps = n_warmup_steps
11
+ self.n_steps = 0
12
+
13
+
14
+ def step(self):
15
+ "Step with the inner optimizer"
16
+ self._update_learning_rate()
17
+ self._optimizer.step()
18
+
19
+
20
+ def zero_grad(self):
21
+ "Zero out the gradients with the inner optimizer"
22
+ self._optimizer.zero_grad()
23
+
24
+
25
+ def _get_lr_scale(self):
26
+ d_model = self.d_model
27
+ n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps
28
+ return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5))
29
+
30
+ def state_dict(self):
31
+ optimizer_state_dict = {
32
+ 'init_lr':self.init_lr,
33
+ 'd_model':self.d_model,
34
+ 'n_warmup_steps':self.n_warmup_steps,
35
+ 'n_steps':self.n_steps,
36
+ '_optimizer':self._optimizer.state_dict(),
37
+ }
38
+
39
+ return optimizer_state_dict
40
+
41
+ def load_state_dict(self, state_dict):
42
+ self.init_lr = state_dict['init_lr']
43
+ self.d_model = state_dict['d_model']
44
+ self.n_warmup_steps = state_dict['n_warmup_steps']
45
+ self.n_steps = state_dict['n_steps']
46
+
47
+ self._optimizer.load_state_dict(state_dict['_optimizer'])
48
+
49
+ def _update_learning_rate(self):
50
+ ''' Learning rate scheduling per step '''
51
+
52
+ self.n_steps += 1
53
+
54
+ for param_group in self._optimizer.param_groups:
55
+ lr = self.init_lr*self._get_lr_scale()
56
+ self.lr = lr
57
+
58
+ param_group['lr'] = lr
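
A sketch of wrapping Adam with ScheduledOptim (Noam-style schedule: the learning rate ramps up for n_warmup_steps and then decays as 1/sqrt(step)); the model, init_lr and warmup values are illustrative.

import torch
from torch.optim import Adam
from vietocr.vietocr.optim.optim import ScheduledOptim

model = torch.nn.Linear(512, 512)
optimizer = ScheduledOptim(Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
                           d_model=512, init_lr=2.0, n_warmup_steps=4000)

loss = model(torch.randn(4, 512)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()                          # updates the lr, then steps the inner optimizer
print(optimizer.lr)                       # learning rate set by the schedule on this step
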
vietocr/vietocr/tool/__init__.py ADDED
File without changes
vietocr/vietocr/tool/config.py ADDED
@@ -0,0 +1,40 @@
1
+ import yaml
2
+ from vietocr.vietocr.tool.utils import download_config
3
+
4
+ url_config = {
5
+ 'vgg_transformer':'vgg-transformer.yml',
6
+ 'resnet_transformer':'resnet_transformer.yml',
7
+ 'resnet_fpn_transformer':'resnet_fpn_transformer.yml',
8
+ 'vgg_seq2seq':'vgg-seq2seq.yml',
9
+ 'vgg_convseq2seq':'vgg_convseq2seq.yml',
10
+ 'vgg_decoderseq2seq':'vgg_decoderseq2seq.yml',
11
+ 'base':'base.yml',
12
+ }
13
+
14
+ class Cfg(dict):
15
+ def __init__(self, config_dict):
16
+ super(Cfg, self).__init__(**config_dict)
17
+ self.__dict__ = self
18
+
19
+ @staticmethod
20
+ def load_config_from_file(fname):
21
+ #base_config = download_config(url_config['base'])
22
+ base_config = {}
23
+ with open(fname, encoding='utf-8') as f:
24
+ config = yaml.safe_load(f)
25
+ base_config.update(config)
26
+
27
+ return Cfg(base_config)
28
+
29
+ @staticmethod
30
+ def load_config_from_name(name):
31
+ base_config = download_config(url_config['base'])
32
+ config = download_config(url_config[name])
33
+
34
+ base_config.update(config)
35
+ return Cfg(base_config)
36
+
37
+ def save(self, fname):
38
+ with open(fname, 'w') as outfile:
39
+ yaml.dump(dict(self), outfile, default_flow_style=False, allow_unicode=True)
40
+
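
Cfg is a dict whose keys double as attributes (it assigns self.__dict__ = self), so a loaded config can be read or overridden either way before being passed to the trainer or predictor. The path below is illustrative and the snippet assumes the file defines a device entry.

from vietocr.vietocr.tool.config import Cfg

config = Cfg.load_config_from_file('vietocr/config.yml')
config['device'] = 'cuda:0'                # dict-style override
print(config.device)                       # attribute-style read of the same value
config.save('/tmp/config_backup.yml')
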
vietocr/vietocr/tool/create_dataset.py ADDED
@@ -0,0 +1,105 @@
1
+ import sys
2
+ import os
3
+ import lmdb # install lmdb by "pip install lmdb"
4
+ import cv2
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+
8
+ def checkImageIsValid(imageBin):
9
+ isvalid = True
10
+ imgH = None
11
+ imgW = None
12
+
13
+ imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
14
+ try:
15
+ img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
16
+
17
+ imgH, imgW = img.shape[0], img.shape[1]
18
+ if imgH * imgW == 0:
19
+ isvalid = False
20
+ except Exception as e:
21
+ isvalid = False
22
+
23
+ return isvalid, imgH, imgW
24
+
25
+ def writeCache(env, cache):
26
+ with env.begin(write=True) as txn:
27
+ for k, v in cache.items():
28
+ txn.put(k.encode(), v)
29
+
30
+ def createDataset(outputPath, root_dir, annotation_path):
31
+ """
32
+ Create LMDB dataset for CRNN training.
33
+ ARGS:
34
+ outputPath : LMDB output path
35
+ root_dir : root directory that contains the images
36
+ annotation_path : annotation file relative to root_dir, one "<image path>\t<label>" line per sample
39
+ """
40
+
41
+ annotation_path = os.path.join(root_dir, annotation_path)
42
+ annotations = []
43
+ with open(annotation_path, 'r') as ann_file:
44
+ lines = ann_file.readlines()
45
+ # for l in lines:
46
+ # try:
47
+ # annotations.append(l.strip().split('\t'))
48
+ # except:
49
+ # pass
50
+ annotations = [l.strip().split('\t') for l in lines]
51
+
52
+ nSamples = len(annotations)
53
+ env = lmdb.open(outputPath, map_size=1099511627776)
54
+ cache = {}
55
+ cnt = 0
56
+ error = 0
57
+
58
+ pbar = tqdm(range(nSamples), ncols = 100, desc='Create {}'.format(outputPath))
59
+ for i in pbar:
60
+ if len(annotations[i]) >= 2:
61
+ imageFile, label = annotations[i]
62
+ else:
63
+ print("Error: Not enough values to unpack")
64
+ #sys.exit()
65
+
66
+ #imageFile, label = annotations[i]
67
+ imagePath = os.path.join(root_dir, imageFile)
68
+
69
+ if not os.path.exists(imagePath):
70
+ error += 1
71
+ continue
72
+
73
+ with open(imagePath, 'rb') as f:
74
+ imageBin = f.read()
75
+ isvalid, imgH, imgW = checkImageIsValid(imageBin)
76
+
77
+ if not isvalid:
78
+ error += 1
79
+ continue
80
+
81
+ imageKey = 'image-%09d' % cnt
82
+ labelKey = 'label-%09d' % cnt
83
+ pathKey = 'path-%09d' % cnt
84
+ dimKey = 'dim-%09d' % cnt
85
+
86
+ cache[imageKey] = imageBin
87
+ cache[labelKey] = label.encode()
88
+ cache[pathKey] = imageFile.encode()
89
+ cache[dimKey] = np.array([imgH, imgW], dtype=np.int32).tobytes()
90
+
91
+ cnt += 1
92
+
93
+ if cnt % 1000 == 0:
94
+ writeCache(env, cache)
95
+ cache = {}
96
+
97
+ nSamples = cnt-1
98
+ cache['num-samples'] = str(nSamples).encode()
99
+ writeCache(env, cache)
100
+
101
+ if error > 0:
102
+ print('Remove {} invalid images'.format(error))
103
+ print('Created dataset with %d samples' % nSamples)
104
+ sys.stdout.flush()
105
+
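
Building the LMDB store that the data loader reads, as a sketch: the paths are illustrative, and the annotation file must contain one "<image path>\t<label>" line per sample, with image paths relative to root_dir.

from vietocr.vietocr.tool.create_dataset import createDataset

createDataset(outputPath='train_hw',              # becomes the lmdb_path used by the trainer
              root_dir='./data',
              annotation_path='train_annotation.txt')
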
vietocr/vietocr/tool/logger.py ADDED
@@ -0,0 +1,17 @@
1
+ import os
2
+
3
+
4
+ class Logger():
5
+ def __init__(self, fname):
6
+ path, _ = os.path.split(fname)
7
+ os.makedirs(path, exist_ok=True)
8
+
9
+ self.logger = open(fname, 'w')
10
+
11
+ def log(self, string):
12
+ self.logger.write(string+'\n')
13
+ self.logger.flush()
14
+
15
+ def close(self):
16
+ self.logger.close()
17
+
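
Logger usage is one line per event; note that the file is opened in 'w' mode, so pointing two runs at the same path overwrites the first log. The path below is illustrative.

from vietocr.vietocr.tool.logger import Logger

logger = Logger('./train/train.log')       # the parent directory is created if missing
logger.log('iter: 000100 - train loss: 1.234')
logger.close()
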
vietocr/vietocr/tool/predictor.py ADDED
@@ -0,0 +1,85 @@
1
+ from vietocr.vietocr.tool.translate import build_model, translate, translate_beam_search, process_input, predict, batch_translate_beam_search
2
+ from vietocr.vietocr.tool.utils import download_weights
3
+
4
+ import torch
5
+ from collections import defaultdict
6
+
7
+ class Predictor():
8
+ def __init__(self, config):
9
+
10
+ device = config['device']
11
+
12
+ model, vocab = build_model(config)
13
+ weights = '/tmp/weights.pth'
14
+
15
+ if config['weights'].startswith('http'):
16
+ weights = download_weights(config['weights'])
17
+ else:
18
+ weights = config['weights']
19
+
20
+ model.load_state_dict(torch.load(weights, map_location=torch.device(device)))
21
+
22
+ self.config = config
23
+ self.model = model
24
+ self.vocab = vocab
25
+ self.device = device
26
+
27
+ def predict(self, img, return_prob=False):
28
+ img = process_input(img, self.config['dataset']['image_height'],
29
+ self.config['dataset']['image_min_width'], self.config['dataset']['image_max_width'])
30
+ img = img.to(self.config['device'])
31
+
32
+ if self.config['predictor']['beamsearch']:
33
+ sent = translate_beam_search(img, self.model)
34
+ s = sent
35
+ prob = None
36
+ else:
37
+ s, prob = translate(img, self.model)
38
+ s = s[0].tolist()
39
+ prob = prob[0]
40
+
41
+ s = self.vocab.decode(s)
42
+
43
+ if return_prob:
44
+ return s, prob
45
+ else:
46
+ return s
47
+
48
+ def predict_batch(self, imgs, return_prob=False):
49
+ bucket = defaultdict(list)
50
+ bucket_idx = defaultdict(list)
51
+ bucket_pred = {}
52
+
53
+ sents, probs = [0]*len(imgs), [0]*len(imgs)
54
+
55
+ for i, img in enumerate(imgs):
56
+ img = process_input(img, self.config['dataset']['image_height'],
57
+ self.config['dataset']['image_min_width'], self.config['dataset']['image_max_width'])
58
+
59
+ bucket[img.shape[-1]].append(img)
60
+ bucket_idx[img.shape[-1]].append(i)
61
+
62
+
63
+ for k, batch in bucket.items():
64
+ batch = torch.cat(batch, 0).to(self.device)
65
+ s, prob = translate(batch, self.model)
66
+ prob = prob.tolist()
67
+
68
+ s = s.tolist()
69
+ s = self.vocab.batch_decode(s)
70
+
71
+ bucket_pred[k] = (s, prob)
72
+
73
+
74
+ for k in bucket_pred:
75
+ idx = bucket_idx[k]
76
+ sent, prob = bucket_pred[k]
77
+ for i, j in enumerate(idx):
78
+ sents[j] = sent[i]
79
+ probs[j] = prob[i]
80
+
81
+ if return_prob:
82
+ return sents, probs
83
+ else:
84
+ return sents
85
+
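
An inference sketch with Predictor: it assumes a config that defines device, vocab, backbone, cnn, transformer, seq_modeling, the dataset image sizes, predictor.beamsearch and a weights entry (local path or http URL). Paths below are illustrative.

from PIL import Image
from vietocr.vietocr.tool.config import Cfg
from vietocr.vietocr.tool.predictor import Predictor

config = Cfg.load_config_from_file('vietocr/config.yml')
config['weights'] = './weights/transformerocr.pth'
detector = Predictor(config)

img = Image.open('./sample/line.png')
text, prob = detector.predict(img, return_prob=True)
print(text, prob)

lines = detector.predict_batch([img, img])    # width-bucketed batch decoding
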
vietocr/vietocr/tool/translate.py ADDED
@@ -0,0 +1,172 @@
1
+ import torch
2
+ import numpy as np
3
+ import math
4
+ from PIL import Image
5
+ from torch.nn.functional import log_softmax, softmax
6
+
7
+ from vietocr.vietocr.model.transformerocr import VietOCR
8
+ from vietocr.vietocr.model.vocab import Vocab
9
+ from vietocr.vietocr.model.beam import Beam
10
+
11
+ def batch_translate_beam_search(img, model, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2):
12
+ # img: NxCxHxW
13
+ model.eval()
14
+ device = img.device
15
+ sents = []
16
+
17
+ with torch.no_grad():
18
+ src = model.cnn(img)
19
+ # print(src.shape)
20
+ memories = model.transformer.forward_encoder(src)
21
+ for i in range(src.size(0)):
22
+ # memory = memories[:,i,:].repeat(1, beam_size, 1) # TxNxE
23
+ memory = model.transformer.get_memory(memories, i)
24
+ sent = beamsearch(memory, model, device, beam_size, candidates, max_seq_length, sos_token, eos_token)
25
+ sents.append(sent)
26
+
27
+ sents = np.asarray(sents)
28
+
29
+ return sents
30
+
31
+ def translate_beam_search(img, model, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2):
32
+ # img: 1xCxHxW
33
+ model.eval()
34
+ device = img.device
35
+
36
+ with torch.no_grad():
37
+ src = model.cnn(img)
38
+ memory = model.transformer.forward_encoder(src) #TxNxE
39
+ sent = beamsearch(memory, model, device, beam_size, candidates, max_seq_length, sos_token, eos_token)
40
+
41
+ return sent
42
+
43
+ def beamsearch(memory, model, device, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2):
44
+ # memory: Tx1xE
45
+ model.eval()
46
+
47
+ beam = Beam(beam_size=beam_size, min_length=0, n_top=candidates, ranker=None, start_token_id=sos_token, end_token_id=eos_token)
48
+
49
+ with torch.no_grad():
50
+ # memory = memory.repeat(1, beam_size, 1) # TxNxE
51
+ memory = model.transformer.expand_memory(memory, beam_size)
52
+
53
+ for _ in range(max_seq_length):
54
+
55
+ tgt_inp = beam.get_current_state().transpose(0,1).to(device) # TxN
56
+ decoder_outputs, memory = model.transformer.forward_decoder(tgt_inp, memory)
57
+
58
+ log_prob = log_softmax(decoder_outputs[:,-1, :].squeeze(0), dim=-1)
59
+ beam.advance(log_prob.cpu())
60
+
61
+ if beam.done():
62
+ break
63
+
64
+ scores, ks = beam.sort_finished(minimum=1)
65
+
66
+ hypothesises = []
67
+ for i, (times, k) in enumerate(ks[:candidates]):
68
+ hypothesis = beam.get_hypothesis(times, k)
69
+ hypothesises.append(hypothesis)
70
+
71
+ return [1] + [int(i) for i in hypothesises[0][:-1]]
72
+
73
+ def translate(img, model, max_seq_length=128, sos_token=1, eos_token=2):
74
+ "data: BxCXHxW"
75
+ model.eval()
76
+ device = img.device
77
+
78
+ with torch.no_grad():
79
+ src = model.cnn(img)
80
+ memory = model.transformer.forward_encoder(src)
81
+
82
+ translated_sentence = [[sos_token]*len(img)]
83
+ char_probs = [[1]*len(img)]
84
+
85
+ max_length = 0
86
+
87
+ while max_length <= max_seq_length and not all(np.any(np.asarray(translated_sentence).T==eos_token, axis=1)):
88
+
89
+ tgt_inp = torch.LongTensor(translated_sentence).to(device)
90
+
91
+ # output = model(img, tgt_inp, tgt_key_padding_mask=None)
92
+ # output = model.transformer(src, tgt_inp, tgt_key_padding_mask=None)
93
+ output, memory = model.transformer.forward_decoder(tgt_inp, memory)
94
+ output = softmax(output, dim=-1)
95
+ output = output.to('cpu')
96
+
97
+ values, indices = torch.topk(output, 5)
98
+
99
+ indices = indices[:, -1, 0]
100
+ indices = indices.tolist()
101
+
102
+ values = values[:, -1, 0]
103
+ values = values.tolist()
104
+ char_probs.append(values)
105
+
106
+ translated_sentence.append(indices)
107
+ max_length += 1
108
+
109
+ del output
110
+
111
+ translated_sentence = np.asarray(translated_sentence).T
112
+
113
+ char_probs = np.asarray(char_probs).T
114
+ char_probs = np.multiply(char_probs, translated_sentence>3)
115
+ char_probs = np.sum(char_probs, axis=-1)/(char_probs>0).sum(-1)
116
+
117
+ return translated_sentence, char_probs
118
+
119
+
120
+ def build_model(config):
121
+ vocab = Vocab(config['vocab'])
122
+ device = config['device']
123
+
124
+ model = VietOCR(len(vocab),
125
+ config['backbone'],
126
+ config['cnn'],
127
+ config['transformer'],
128
+ config['seq_modeling'])
129
+
130
+ model = model.to(device)
131
+
132
+ return model, vocab
133
+
134
+ def resize(w, h, expected_height, image_min_width, image_max_width):
135
+ new_w = int(expected_height * float(w) / float(h))
136
+ round_to = 10
137
+ new_w = math.ceil(new_w/round_to)*round_to
138
+ new_w = max(new_w, image_min_width)
139
+ new_w = min(new_w, image_max_width)
140
+
141
+ return new_w, expected_height
142
+
143
+ def process_image(image, image_height, image_min_width, image_max_width):
144
+ img = image.convert('RGB')
145
+
146
+ w, h = img.size
147
+ new_w, image_height = resize(w, h, image_height, image_min_width, image_max_width)
148
+
149
+ img = img.resize((new_w, image_height), Image.LANCZOS)
150
+
151
+ img = np.asarray(img).transpose(2,0, 1)
152
+ img = img/255
153
+ return img
154
+
155
+ def process_input(image, image_height, image_min_width, image_max_width):
156
+ img = process_image(image, image_height, image_min_width, image_max_width)
157
+ img = img[np.newaxis, ...]
158
+ img = torch.FloatTensor(img)
159
+ return img
160
+
161
+ def predict(filename, config):
162
+ img = Image.open(filename)
163
+ img = process_input(img, config['dataset']['image_height'], config['dataset']['image_min_width'], config['dataset']['image_max_width'])
164
+
165
+ img = img.to(config['device'])
166
+
167
+ model, vocab = build_model(config)
168
+ s = translate(img, model)[0][0].tolist()
169
+ s = vocab.decode(s)
170
+
171
+ return s
172
+
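
A shape check for the preprocessing helpers: the height is fixed, the width is scaled to preserve the aspect ratio, rounded up to a multiple of 10 and clamped to [image_min_width, image_max_width]. The sizes below are illustrative.

from PIL import Image
from vietocr.vietocr.tool.translate import resize, process_input

print(resize(400, 50, expected_height=32, image_min_width=32, image_max_width=512))
# (260, 32): 400 * 32 / 50 = 256, rounded up to the next multiple of 10

img = Image.new('RGB', (400, 50))
tensor = process_input(img, image_height=32, image_min_width=32, image_max_width=512)
print(tensor.shape)                        # torch.Size([1, 3, 32, 260])
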
vietocr/vietocr/tool/utils.py ADDED
@@ -0,0 +1,93 @@
1
+ import os
2
+ import gdown
3
+ import yaml
4
+ import numpy as np
5
+ import uuid
6
+ import requests
7
+ import tempfile
8
+ from tqdm import tqdm
9
+
10
+ def download_weights(uri, cached=None, md5=None, quiet=False):
11
+ if uri.startswith('http'):
12
+ return download(url=uri, quiet=quiet)
13
+ return uri
14
+
15
+ def download(url, quiet=False):
16
+ tmp_dir = tempfile.gettempdir()
17
+ filename = url.split('/')[-1]
18
+ full_path = os.path.join(tmp_dir, filename)
19
+
20
+ if os.path.exists(full_path):
21
+ print('Model weight {} exists. Skipping download.'.format(full_path))
22
+ return full_path
23
+
24
+ with requests.get(url, stream=True) as r:
25
+ r.raise_for_status()
26
+ with open(full_path, 'wb') as f:
27
+ for chunk in tqdm(r.iter_content(chunk_size=8192)):
28
+ # If you have chunk encoded response uncomment if
29
+ # and set chunk_size parameter to None.
30
+ #if chunk:
31
+ f.write(chunk)
32
+ return full_path
33
+
34
+ def download_config(id):
35
+ url = 'https://vocr.vn/data/vietocr/config/{}'.format(id)
36
+ r = requests.get(url)
37
+ config = yaml.safe_load(r.text)
38
+ return config
39
+
40
+ def compute_accuracy(ground_truth, predictions, mode='full_sequence'):
41
+ """
42
+ Computes accuracy
43
+ :param ground_truth: list of ground-truth label strings
44
+ :param predictions: list of predicted label strings
46
+ :param mode: if 'per_char' is selected then
47
+ single_label_accuracy = correct_predicted_char_nums_of_single_sample / single_label_char_nums
48
+ avg_label_accuracy = sum(single_label_accuracy) / label_nums
49
+ if 'full_sequence' is selected then
50
+ single_label_accuracy = 1 if the prediction result is exactly the same as label else 0
51
+ avg_label_accuracy = sum(single_label_accuracy) / label_nums
52
+ :return: avg_label_accuracy
53
+ """
54
+ if mode == 'per_char':
55
+
56
+ accuracy = []
57
+
58
+ for index, label in enumerate(ground_truth):
59
+ prediction = predictions[index]
60
+ total_count = len(label)
61
+ correct_count = 0
62
+ try:
63
+ for i, tmp in enumerate(label):
64
+ if tmp == prediction[i]:
65
+ correct_count += 1
66
+ except IndexError:
67
+ continue
68
+ finally:
69
+ try:
70
+ accuracy.append(correct_count / total_count)
71
+ except ZeroDivisionError:
72
+ if len(prediction) == 0:
73
+ accuracy.append(1)
74
+ else:
75
+ accuracy.append(0)
76
+ avg_accuracy = np.mean(np.array(accuracy).astype(np.float32), axis=0)
77
+ elif mode == 'full_sequence':
78
+ try:
79
+ correct_count = 0
80
+ for index, label in enumerate(ground_truth):
81
+ prediction = predictions[index]
82
+ if prediction == label:
83
+ correct_count += 1
84
+ avg_accuracy = correct_count / len(ground_truth)
85
+ except ZeroDivisionError:
86
+ if not predictions:
87
+ avg_accuracy = 1
88
+ else:
89
+ avg_accuracy = 0
90
+ else:
91
+ raise NotImplementedError('Other accuracy compute mode has not been implemented')
92
+
93
+ return avg_accuracy
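
A toy example of compute_accuracy: 'full_sequence' counts exact string matches, while 'per_char' averages the per-position hit rate of each ground-truth/prediction pair.

from vietocr.vietocr.tool.utils import compute_accuracy

gt = ['hello', 'world']
pred = ['hello', 'w0rld']
print(compute_accuracy(gt, pred, mode='full_sequence'))   # 0.5  (1 exact match out of 2)
print(compute_accuracy(gt, pred, mode='per_char'))        # 0.9  ((5/5 + 4/5) / 2)
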