tomofi committed
Commit
cb433d6
1 Parent(s): 3c2ad8a

Add application file

LICENSE ADDED
@@ -0,0 +1,25 @@
+ ABINet for non-commercial purposes
+
+ Copyright (c) 2021, USTC
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,13 +1,138 @@
- ---
- title: ABINet OCR
- emoji: 🏃
- colorFrom: indigo
- colorTo: red
- sdk: gradio
- sdk_version: 2.8.12
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
+ # Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition
+
+ The official code of [ABINet](https://arxiv.org/pdf/2103.06495.pdf) (CVPR 2021, Oral).
+
+ ABINet uses a vision model and an explicit language model to recognize text in the wild; the two are trained in an end-to-end way. The language model (BCN) achieves a bidirectional language representation by simulating a cloze test and additionally applies an iterative correction strategy.
+
+ ![framework](./figs/framework.png)
+
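+ At a glance, inference can be thought of as the loop sketched below. This is a conceptual sketch only, not the actual `modules.model_abinet_iter.ABINetIterModel` code: the vision prediction is repeatedly refined by the language model and fused by an alignment module for `iter_size` rounds (`iter_size: 3` in `configs/train_abinet.yaml`).
+
+ ```python
+ # Conceptual sketch only; the real sub-modules live in modules/ and are
+ # wired together by ABINetIterModel. The names `vision`, `language`,
+ # `alignment` stand for the three sub-models described above.
+ def recognize(image, vision, language, alignment, iter_size=3):
+     vis_out = vision(image)                  # initial visual prediction
+     fused = vis_out
+     for _ in range(iter_size):               # iterative correction
+         refined = language(fused)            # BCN refines the text hypothesis
+         fused = alignment(vis_out, refined)  # fuse visual and linguistic cues
+     return fused                             # final prediction over the charset
+ ```
+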
+ ## Runtime Environment
+
+ - We provide a pre-built docker image built from the Dockerfile in `docker/Dockerfile`
+
+ - Running in Docker
+ ```
+ $ git clone git@github.com:FangShancheng/ABINet.git
+ $ docker run --gpus all --rm -ti --ipc=host -v $(pwd)/ABINet:/app fangshancheng/fastai:torch1.1 /bin/bash
+ ```
+ - (Untested) Or install the dependencies directly
+ ```
+ pip install -r requirements.txt
+ ```
+
+ ## Datasets
+
+ - Training datasets
+
+     1. [MJSynth](http://www.robots.ox.ac.uk/~vgg/data/text/) (MJ):
+         - Use `tools/create_lmdb_dataset.py` to convert the images into an LMDB dataset
+         - [LMDB dataset BaiduNetdisk(passwd:n23k)](https://pan.baidu.com/s/1mgnTiyoR8f6Cm655rFI4HQ)
+     2. [SynthText](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) (ST):
+         - Use `tools/crop_by_word_bb.py` to crop images from the original [SynthText](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) dataset, then convert them into an LMDB dataset with `tools/create_lmdb_dataset.py`
+         - [LMDB dataset BaiduNetdisk(passwd:n23k)](https://pan.baidu.com/s/1mgnTiyoR8f6Cm655rFI4HQ)
+     3. [WikiText103](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip), which is only used for pre-training the language model:
+         - Use `notebooks/prepare_wikitext103.ipynb` to convert the text into CSV format.
+         - [CSV dataset BaiduNetdisk(passwd:dk01)](https://pan.baidu.com/s/1yabtnPYDKqhBb_Ie9PGFXA)
+
+ - Evaluation datasets: the LMDB datasets can be downloaded from [BaiduNetdisk(passwd:1dbv)](https://pan.baidu.com/s/1RUg3Akwp7n8kZYJ55rU5LQ) or [GoogleDrive](https://drive.google.com/file/d/1dTI0ipu14Q1uuK4s4z32DqbqF3dJPdkk/view?usp=sharing) (a minimal loading sketch is shown after the directory tree below).
+     1. ICDAR 2013 (IC13)
+     2. ICDAR 2015 (IC15)
+     3. IIIT5K Words (IIIT)
+     4. Street View Text (SVT)
+     5. Street View Text-Perspective (SVTP)
+     6. CUTE80 (CUTE)
+
+
+ - The structure of the `data` directory is
+ ```
+ data
+ ├── charset_36.txt
+ ├── evaluation
+ │   ├── CUTE80
+ │   ├── IC13_857
+ │   ├── IC15_1811
+ │   ├── IIIT5k_3000
+ │   ├── SVT
+ │   └── SVTP
+ ├── training
+ │   ├── MJ
+ │   │   ├── MJ_test
+ │   │   ├── MJ_train
+ │   │   └── MJ_valid
+ │   └── ST
+ ├── WikiText-103.csv
+ └── WikiText-103_eval_d1.csv
+ ```
+
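+ A minimal loading sketch for one of the evaluation LMDB folders, using the `ImageDataset` class added in `dataset.py` of this commit (the path and sizes are assumptions matching the layout above and `configs/template.yaml`):
+
+ ```python
+ from dataset import ImageDataset
+
+ # is_training=False disables augmentation; img_h/img_w match the config defaults
+ ds = ImageDataset('data/evaluation/CUTE80', is_training=False,
+                   img_h=32, img_w=128, charset_path='data/charset_36.txt')
+ image, (label, length) = ds[0]   # image tensor plus one-hot label and its length
+ print(image.shape, length)
+ ```
+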
+ ### Pretrained Models
+
+ Get the pretrained models from [BaiduNetdisk(passwd:kwck)](https://pan.baidu.com/s/1b3vyvPwvh_75FkPlp87czQ) or [GoogleDrive](https://drive.google.com/file/d/1mYM_26qHUom_5NU7iutHneB_KHlLjL5y/view?usp=sharing). The performance of the pretrained models is summarized as follows:
+
+ |Model|IC13|SVT|IIIT|IC15|SVTP|CUTE|AVG|
+ |-|-|-|-|-|-|-|-|
+ |ABINet-SV|97.1|92.7|95.2|84.0|86.7|88.5|91.4|
+ |ABINet-LV|97.0|93.4|96.4|85.9|89.5|89.2|92.7|
+
+ ## Training
+
+ 1. Pre-train the vision model
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config=configs/pretrain_vision_model.yaml
+ ```
+ 2. Pre-train the language model
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config=configs/pretrain_language_model.yaml
+ ```
+ 3. Train ABINet
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config=configs/train_abinet.yaml
+ ```
+ Note:
+ - You can set the `checkpoint` path for the vision and language models separately to start from specific pretrained models, or set it to `None` to train from scratch (see the sketch below)
+
+
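+ For reference, a sketch of the same override done programmatically, mirroring how `app.py` and `demo.py` in this commit set the flattened `Config` attributes; editing the `model.vision.checkpoint` / `model.language.checkpoint` keys in the YAML directly is equivalent:
+
+ ```python
+ from utils import Config
+
+ config = Config('configs/train_abinet.yaml')
+ # Point the vision branch at a specific pretrained model ...
+ config.model_vision_checkpoint = 'workdir/pretrain-vision-model/best-pretrain-vision-model.pth'
+ # ... and train the language branch from scratch
+ config.model_language_checkpoint = None
+ ```
+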
+ ## Evaluation
+
+ ```
+ CUDA_VISIBLE_DEVICES=0 python main.py --config=configs/train_abinet.yaml --phase test --image_only
+ ```
+ Additional flags (a combined example follows the list):
+ - `--checkpoint /path/to/checkpoint` set the path of the model to evaluate
+ - `--test_root /path/to/dataset` set the path of the evaluation dataset
+ - `--model_eval [alignment|vision]` choose which sub-model to evaluate
+ - `--image_only` disable dumping visualizations of attention masks
+
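+ For example, to evaluate only the vision sub-model of a trained ABINet checkpoint on CUTE80 (the paths are illustrative):
+
+ ```
+ CUDA_VISIBLE_DEVICES=0 python main.py --config=configs/train_abinet.yaml --phase test --image_only \
+     --checkpoint workdir/train-abinet/best-train-abinet.pth --model_eval vision --test_root data/evaluation/CUTE80
+ ```
+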
+ ## Run Demo
+
+ ```
+ python demo.py --config=configs/train_abinet.yaml --input=figs/test
+ ```
+ Additional flags:
+ - `--config /path/to/config` set the path of the configuration file
+ - `--input /path/to/image-directory` set the path of an image directory or a wildcard path, e.g. `--input='figs/test/*.png'`
+ - `--checkpoint /path/to/checkpoint` set the path of the trained model
+ - `--cuda [-1|0|1|2|3...]` set the CUDA device id; the default is -1, which stands for CPU
+ - `--model_eval [alignment|vision]` choose which sub-model to use
+ - `--image_only` disable dumping visualizations of attention masks
+
+ ## Visualization
+ Success and failure cases on low-quality images:
+
+ ![cases](./figs/cases.png)
+
+ ## Citation
+ If you find our method useful for your research, please cite
+ ```bibtex
+ @inproceedings{fang2021read,
+   title={Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition},
+   author={Fang, Shancheng and Xie, Hongtao and Wang, Yuxin and Mao, Zhendong and Zhang, Yongdong},
+   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+   year={2021}
+ }
+ ```
+
+ ## License
+
+ This project is free for academic research purposes only and is licensed under the 2-clause BSD License; see the LICENSE file for details.
+
+ Feel free to contact fangsc@ustc.edu.cn if you have any questions.
app.py ADDED
@@ -0,0 +1,34 @@
+ import os
+ import gdown
+ gdown.download(id='1mYM_26qHUom_5NU7iutHneB_KHlLjL5y', output='workdir.zip')  # pretrained models (workdir/) from the authors' Google Drive
+ os.system('unzip workdir.zip')
+
+ import glob
+ import gradio as gr
+ from demo import get_model, preprocess, postprocess, load
+ from utils import Config, Logger, CharsetMapper
+
+ def process_image(image):
+     config = Config('configs/train_abinet.yaml')
+     config.model_vision_checkpoint = None  # the full ABINet checkpoint below already contains the vision weights
+     model = get_model(config)
+     model = load(model, 'workdir/train-abinet/best-train-abinet.pth')
+     charset = CharsetMapper(filename=config.dataset_charset_path, max_length=config.dataset_max_length + 1)
+
+     img = image.convert('RGB')
+     img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
+     res = model(img)
+     return postprocess(res, charset, 'alignment')[0][0]
+
+ title = "Interactive demo: ABINet"
+ description = "Demo for ABINet. ABINet uses a vision model and an explicit language model to recognize text in the wild; the two are trained in an end-to-end way. The language model (BCN) achieves a bidirectional language representation by simulating a cloze test and additionally applies an iterative correction strategy. To use it, simply upload a (single text-line) image or use one of the example images below and click 'submit'. Results will show up in a few seconds."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/pdf/2103.06495.pdf'>Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition</a> | <a href='https://github.com/FangShancheng/ABINet'>Github Repo</a></p>"
+
+ iface = gr.Interface(fn=process_image,
+                      inputs=gr.inputs.Image(type="pil"),
+                      outputs=gr.outputs.Textbox(),
+                      title=title,
+                      description=description,
+                      article=article,
+                      examples=glob.glob('figs/test/*.png'))
+ iface.launch(debug=True)
callbacks.py ADDED
@@ -0,0 +1,360 @@
1
+ import logging
2
+ import shutil
3
+ import time
4
+
5
+ import editdistance as ed
6
+ import torchvision.utils as vutils
7
+ from fastai.callbacks.tensorboard import (LearnerTensorboardWriter,
8
+ SummaryWriter, TBWriteRequest,
9
+ asyncTBWriter)
10
+ from fastai.vision import *
11
+ from torch.nn.parallel import DistributedDataParallel
12
+ from torchvision import transforms
13
+
14
+ import dataset
15
+ from utils import CharsetMapper, Timer, blend_mask
16
+
17
+
18
+ class IterationCallback(LearnerTensorboardWriter):
19
+ "A `TrackerCallback` that monitors training at each iteration."
20
+ def __init__(self, learn:Learner, name:str='model', checpoint_keep_num=5,
21
+ show_iters:int=50, eval_iters:int=1000, save_iters:int=20000,
22
+ start_iters:int=0, stats_iters=20000):
23
+ #if self.learn.rank is not None: time.sleep(self.learn.rank) # keep all event files
24
+ super().__init__(learn, base_dir='.', name=learn.path, loss_iters=show_iters,
25
+ stats_iters=stats_iters, hist_iters=stats_iters)
26
+ self.name, self.bestname = Path(name).name, f'best-{Path(name).name}'
27
+ self.show_iters = show_iters
28
+ self.eval_iters = eval_iters
29
+ self.save_iters = save_iters
30
+ self.start_iters = start_iters
31
+ self.checpoint_keep_num = checpoint_keep_num
32
+ self.metrics_root = 'metrics/' # rewrite
33
+ self.timer = Timer()
34
+ self.host = self.learn.rank is None or self.learn.rank == 0
35
+
36
+ def _write_metrics(self, iteration:int, names:List[str], last_metrics:MetricsList)->None:
37
+ "Writes training metrics to Tensorboard."
38
+ for i, name in enumerate(names):
39
+ if last_metrics is None or len(last_metrics) < i+1: return
40
+ scalar_value = last_metrics[i]
41
+ self._write_scalar(name=name, scalar_value=scalar_value, iteration=iteration)
42
+
43
+ def _write_sub_loss(self, iteration:int, last_losses:dict)->None:
44
+ "Writes sub loss to Tensorboard."
45
+ for name, loss in last_losses.items():
46
+ scalar_value = to_np(loss)
47
+ tag = self.metrics_root + name
48
+ self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
49
+
50
+ def _save(self, name):
51
+ if isinstance(self.learn.model, DistributedDataParallel):
52
+ tmp = self.learn.model
53
+ self.learn.model = self.learn.model.module
54
+ self.learn.save(name)
55
+ self.learn.model = tmp
56
+ else: self.learn.save(name)
57
+
58
+ def _validate(self, dl=None, callbacks=None, metrics=None, keeped_items=False):
59
+ "Validate on `dl` with potential `callbacks` and `metrics`."
60
+ dl = ifnone(dl, self.learn.data.valid_dl)
61
+ metrics = ifnone(metrics, self.learn.metrics)
62
+ cb_handler = CallbackHandler(ifnone(callbacks, []), metrics)
63
+ cb_handler.on_train_begin(1, None, metrics); cb_handler.on_epoch_begin()
64
+ if keeped_items: cb_handler.state_dict.update(dict(keeped_items=[]))
65
+ val_metrics = validate(self.learn.model, dl, self.loss_func, cb_handler)
66
+ cb_handler.on_epoch_end(val_metrics)
67
+ if keeped_items: return cb_handler.state_dict['keeped_items']
68
+ else: return cb_handler.state_dict['last_metrics']
69
+
70
+ def jump_to_epoch_iter(self, epoch:int, iteration:int)->None:
71
+ try:
72
+ self.learn.load(f'{self.name}_{epoch}_{iteration}', purge=False)
73
+ logging.info(f'Loaded {self.name}_{epoch}_{iteration}')
74
+ except: logging.info(f'Model {self.name}_{epoch}_{iteration} not found.')
75
+
76
+ def on_train_begin(self, n_epochs, **kwargs):
77
+ # TODO: can not write graph here
78
+ # super().on_train_begin(**kwargs)
79
+ self.best = -float('inf')
80
+ self.timer.tic()
81
+ if self.host:
82
+ checkpoint_path = self.learn.path/'checkpoint.yaml'
83
+ if checkpoint_path.exists():
84
+ os.remove(checkpoint_path)
85
+ open(checkpoint_path, 'w').close()
86
+ return {'skip_validate': True, 'iteration':self.start_iters} # disable default validate
87
+
88
+ def on_batch_begin(self, **kwargs:Any)->None:
89
+ self.timer.toc_data()
90
+ super().on_batch_begin(**kwargs)
91
+
92
+ def on_batch_end(self, iteration, epoch, last_loss, smooth_loss, train, **kwargs):
93
+ super().on_batch_end(last_loss, iteration, train, **kwargs)
94
+ if iteration == 0: return
95
+
96
+ if iteration % self.loss_iters == 0:
97
+ last_losses = self.learn.loss_func.last_losses
98
+ self._write_sub_loss(iteration=iteration, last_losses=last_losses)
99
+ self.tbwriter.add_scalar(tag=self.metrics_root + 'lr',
100
+ scalar_value=self.opt.lr, global_step=iteration)
101
+
102
+ if iteration % self.show_iters == 0:
103
+ log_str = f'epoch {epoch} iter {iteration}: loss = {last_loss:6.4f}, ' \
104
+ f'smooth loss = {smooth_loss:6.4f}'
105
+ logging.info(log_str)
106
+ # log_str = f'data time = {self.timer.data_diff:.4f}s, runing time = {self.timer.running_diff:.4f}s'
107
+ # logging.info(log_str)
108
+
109
+ if iteration % self.eval_iters == 0:
110
+ # TODO: or remove time to on_epoch_end
111
+ # 1. Record time
112
+ log_str = f'average data time = {self.timer.average_data_time():.4f}s, ' \
113
+ f'average running time = {self.timer.average_running_time():.4f}s'
114
+ logging.info(log_str)
115
+
116
+ # 2. Call validate
117
+ last_metrics = self._validate()
118
+ self.learn.model.train()
119
+ log_str = f'epoch {epoch} iter {iteration}: eval loss = {last_metrics[0]:6.4f}, ' \
120
+ f'ccr = {last_metrics[1]:6.4f}, cwr = {last_metrics[2]:6.4f}, ' \
121
+ f'ted = {last_metrics[3]:6.4f}, ned = {last_metrics[4]:6.4f}, ' \
122
+ f'ted/w = {last_metrics[5]:6.4f}, '
123
+ logging.info(log_str)
124
+ names = ['eval_loss', 'ccr', 'cwr', 'ted', 'ned', 'ted/w']
125
+ self._write_metrics(iteration, names, last_metrics)
126
+
127
+ # 3. Save best model
128
+ current = last_metrics[2]
129
+ if current is not None and current > self.best:
130
+ logging.info(f'Better model found at epoch {epoch}, '\
131
+ f'iter {iteration} with accuracy value: {current:6.4f}.')
132
+ self.best = current
133
+ self._save(f'{self.bestname}')
134
+
135
+ if iteration % self.save_iters == 0 and self.host:
136
+ logging.info(f'Save model {self.name}_{epoch}_{iteration}')
137
+ filename = f'{self.name}_{epoch}_{iteration}'
138
+ self._save(filename)
139
+
140
+ checkpoint_path = self.learn.path/'checkpoint.yaml'
141
+ if not checkpoint_path.exists():
142
+ open(checkpoint_path, 'w').close()
143
+ with open(checkpoint_path, 'r') as file:
144
+ checkpoints = yaml.load(file, Loader=yaml.FullLoader) or dict()
145
+ checkpoints['all_checkpoints'] = (
146
+ checkpoints.get('all_checkpoints') or list())
147
+ checkpoints['all_checkpoints'].insert(0, filename)
148
+ if len(checkpoints['all_checkpoints']) > self.checpoint_keep_num:
149
+ removed_checkpoint = checkpoints['all_checkpoints'].pop()
150
+ removed_checkpoint = self.learn.path/self.learn.model_dir/f'{removed_checkpoint}.pth'
151
+ os.remove(removed_checkpoint)
152
+ checkpoints['current_checkpoint'] = filename
153
+ with open(checkpoint_path, 'w') as file:
154
+ yaml.dump(checkpoints, file)
155
+
156
+
157
+ self.timer.toc_running()
158
+
159
+ def on_train_end(self, **kwargs):
160
+ #self.learn.load(f'{self.bestname}', purge=False)
161
+ pass
162
+
163
+ def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs)->None:
164
+ self._write_embedding(iteration=iteration)
165
+
166
+
167
+ class TextAccuracy(Callback):
168
+ _names = ['ccr', 'cwr', 'ted', 'ned', 'ted/w']
169
+ def __init__(self, charset_path, max_length, case_sensitive, model_eval):
170
+ self.charset_path = charset_path
171
+ self.max_length = max_length
172
+ self.case_sensitive = case_sensitive
173
+ self.charset = CharsetMapper(charset_path, self.max_length)
174
+ self.names = self._names
175
+
176
+ self.model_eval = model_eval or 'alignment'
177
+ assert self.model_eval in ['vision', 'language', 'alignment']
178
+
179
+ def on_epoch_begin(self, **kwargs):
180
+ self.total_num_char = 0.
181
+ self.total_num_word = 0.
182
+ self.correct_num_char = 0.
183
+ self.correct_num_word = 0.
184
+ self.total_ed = 0.
185
+ self.total_ned = 0.
186
+
187
+ def _get_output(self, last_output):
188
+ if isinstance(last_output, (tuple, list)):
189
+ for res in last_output:
190
+ if res['name'] == self.model_eval: output = res
191
+ else: output = last_output
192
+ return output
193
+
194
+ def _update_output(self, last_output, items):
195
+ if isinstance(last_output, (tuple, list)):
196
+ for res in last_output:
197
+ if res['name'] == self.model_eval: res.update(items)
198
+ else: last_output.update(items)
199
+ return last_output
200
+
201
+ def on_batch_end(self, last_output, last_target, **kwargs):
202
+ output = self._get_output(last_output)
203
+ logits, pt_lengths = output['logits'], output['pt_lengths']
204
+ pt_text, pt_scores, pt_lengths_ = self.decode(logits)
205
+ assert (pt_lengths == pt_lengths_).all(), f'{pt_lengths} != {pt_lengths_} for {pt_text}'
206
+ last_output = self._update_output(last_output, {'pt_text':pt_text, 'pt_scores':pt_scores})
207
+
208
+ pt_text = [self.charset.trim(t) for t in pt_text]
209
+ label = last_target[0]
210
+ if label.dim() == 3: label = label.argmax(dim=-1) # one-hot label
211
+ gt_text = [self.charset.get_text(l, trim=True) for l in label]
212
+
213
+ for i in range(len(gt_text)):
214
+ if not self.case_sensitive:
215
+ gt_text[i], pt_text[i] = gt_text[i].lower(), pt_text[i].lower()
216
+ distance = ed.eval(gt_text[i], pt_text[i])
217
+ self.total_ed += distance
218
+ self.total_ned += float(distance) / max(len(gt_text[i]), 1)
219
+
220
+ if gt_text[i] == pt_text[i]:
221
+ self.correct_num_word += 1
222
+ self.total_num_word += 1
223
+
224
+ for j in range(min(len(gt_text[i]), len(pt_text[i]))):
225
+ if gt_text[i][j] == pt_text[i][j]:
226
+ self.correct_num_char += 1
227
+ self.total_num_char += len(gt_text[i])
228
+
229
+ return {'last_output': last_output}
230
+
231
+ def on_epoch_end(self, last_metrics, **kwargs):
232
+ mets = [self.correct_num_char / self.total_num_char,
233
+ self.correct_num_word / self.total_num_word,
234
+ self.total_ed,
235
+ self.total_ned,
236
+ self.total_ed / self.total_num_word]
237
+ return add_metrics(last_metrics, mets)
238
+
239
+ def decode(self, logit):
240
+ """ Greedy decode """
241
+ # TODO: test running time and decode on GPU
242
+ out = F.softmax(logit, dim=2)
243
+ pt_text, pt_scores, pt_lengths = [], [], []
244
+ for o in out:
245
+ text = self.charset.get_text(o.argmax(dim=1), padding=False, trim=False)
246
+ text = text.split(self.charset.null_char)[0] # end at end-token
247
+ pt_text.append(text)
248
+ pt_scores.append(o.max(dim=1)[0])
249
+ pt_lengths.append(min(len(text) + 1, self.max_length)) # one for end-token
250
+ pt_scores = torch.stack(pt_scores)
251
+ pt_lengths = pt_scores.new_tensor(pt_lengths, dtype=torch.long)
252
+ return pt_text, pt_scores, pt_lengths
253
+
254
+
255
+ class TopKTextAccuracy(TextAccuracy):
256
+ _names = ['ccr', 'cwr']
257
+ def __init__(self, k, charset_path, max_length, case_sensitive, model_eval):
258
+ self.k = k
259
+ self.charset_path = charset_path
260
+ self.max_length = max_length
261
+ self.case_sensitive = case_sensitive
262
+ self.charset = CharsetMapper(charset_path, self.max_length)
263
+ self.names = self._names
264
+
265
+ def on_epoch_begin(self, **kwargs):
266
+ self.total_num_char = 0.
267
+ self.total_num_word = 0.
268
+ self.correct_num_char = 0.
269
+ self.correct_num_word = 0.
270
+
271
+ def on_batch_end(self, last_output, last_target, **kwargs):
272
+ logits, pt_lengths = last_output['logits'], last_output['pt_lengths']
273
+ gt_labels, gt_lengths = last_target[:]
274
+
275
+ for logit, pt_length, label, length in zip(logits, pt_lengths, gt_labels, gt_lengths):
276
+ word_flag = True
277
+ for i in range(length):
278
+ char_logit = logit[i].topk(self.k)[1]
279
+ char_label = label[i].argmax(-1)
280
+ if char_label in char_logit: self.correct_num_char += 1
281
+ else: word_flag = False
282
+ self.total_num_char += 1
283
+ if pt_length == length and word_flag:
284
+ self.correct_num_word += 1
285
+ self.total_num_word += 1
286
+
287
+ def on_epoch_end(self, last_metrics, **kwargs):
288
+ mets = [self.correct_num_char / self.total_num_char,
289
+ self.correct_num_word / self.total_num_word,
290
+ 0., 0., 0.]
291
+ return add_metrics(last_metrics, mets)
292
+
293
+
294
+ class DumpPrediction(LearnerCallback):
295
+
296
+ def __init__(self, learn, dataset, charset_path, model_eval, image_only=False, debug=False):
297
+ super().__init__(learn=learn)
298
+ self.debug = debug
299
+ self.model_eval = model_eval or 'alignment'
300
+ self.image_only = image_only
301
+ assert self.model_eval in ['vision', 'language', 'alignment']
302
+
303
+ self.dataset, self.root = dataset, Path(self.learn.path)/f'{dataset}-{self.model_eval}'
304
+ self.attn_root = self.root/'attn'
305
+ self.charset = CharsetMapper(charset_path)
306
+ if self.root.exists(): shutil.rmtree(self.root)
307
+ self.root.mkdir(), self.attn_root.mkdir()
308
+
309
+ self.pil = transforms.ToPILImage()
310
+ self.tensor = transforms.ToTensor()
311
+ size = self.learn.data.img_h, self.learn.data.img_w
312
+ self.resize = transforms.Resize(size=size, interpolation=0)
313
+ self.c = 0
314
+
315
+ def on_batch_end(self, last_input, last_output, last_target, **kwargs):
316
+ if isinstance(last_output, (tuple, list)):
317
+ for res in last_output:
318
+ if res['name'] == self.model_eval: pt_text = res['pt_text']
319
+ if res['name'] == 'vision': attn_scores = res['attn_scores'].detach().cpu()
320
+ if res['name'] == self.model_eval: logits = res['logits']
321
+ else:
322
+ pt_text = last_output['pt_text']
323
+ attn_scores = last_output['attn_scores'].detach().cpu()
324
+ logits = last_output['logits']
325
+
326
+ images = last_input[0] if isinstance(last_input, (tuple, list)) else last_input
327
+ images = images.detach().cpu()
328
+ pt_text = [self.charset.trim(t) for t in pt_text]
329
+ gt_label = last_target[0]
330
+ if gt_label.dim() == 3: gt_label = gt_label.argmax(dim=-1) # one-hot label
331
+ gt_text = [self.charset.get_text(l, trim=True) for l in gt_label]
332
+
333
+ prediction, false_prediction = [], []
334
+ for gt, pt, image, attn, logit in zip(gt_text, pt_text, images, attn_scores, logits):
335
+ prediction.append(f'{gt}\t{pt}\n')
336
+ if gt != pt:
337
+ if self.debug:
338
+ scores = torch.softmax(logit, dim=-1)[:max(len(pt), len(gt)) + 1]
339
+ logging.info(f'{self.c} gt {gt}, pt {pt}, logit {logit.shape}, scores {scores.topk(5, dim=-1)}')
340
+ false_prediction.append(f'{gt}\t{pt}\n')
341
+
342
+ image = self.learn.data.denorm(image)
343
+ if not self.image_only:
344
+ image_np = np.array(self.pil(image))
345
+ attn_pil = [self.pil(a) for a in attn[:, None, :, :]]
346
+ attn = [self.tensor(self.resize(a)).repeat(3, 1, 1) for a in attn_pil]
347
+ attn_sum = np.array([np.array(a) for a in attn_pil[:len(pt)]]).sum(axis=0)
348
+ blended_sum = self.tensor(blend_mask(image_np, attn_sum))
349
+ blended = [self.tensor(blend_mask(image_np, np.array(a))) for a in attn_pil]
350
+ save_image = torch.stack([image] + attn + [blended_sum] + blended)
351
+ save_image = save_image.view(2, -1, *save_image.shape[1:])
352
+ save_image = save_image.permute(1, 0, 2, 3, 4).flatten(0, 1)
353
+ vutils.save_image(save_image, self.attn_root/f'{self.c}_{gt}_{pt}.jpg',
354
+ nrow=2, normalize=True, scale_each=True)
355
+ else:
356
+ self.pil(image).save(self.attn_root/f'{self.c}_{gt}_{pt}.jpg')
357
+ self.c += 1
358
+
359
+ with open(self.root/f'{self.model_eval}.txt', 'a') as f: f.writelines(prediction)
360
+ with open(self.root/f'{self.model_eval}-false.txt', 'a') as f: f.writelines(false_prediction)
configs/pretrain_language_model.yaml ADDED
@@ -0,0 +1,45 @@
1
+ global:
2
+ name: pretrain-language-model
3
+ phase: train
4
+ stage: pretrain-language
5
+ workdir: workdir
6
+ seed: ~
7
+
8
+ dataset:
9
+ train: {
10
+ roots: ['data/WikiText-103.csv'],
11
+ batch_size: 4096
12
+ }
13
+ test: {
14
+ roots: ['data/WikiText-103_eval_d1.csv'],
15
+ batch_size: 4096
16
+ }
17
+
18
+ training:
19
+ epochs: 80
20
+ show_iters: 50
21
+ eval_iters: 6000
22
+ save_iters: 3000
23
+
24
+ optimizer:
25
+ type: Adam
26
+ true_wd: False
27
+ wd: 0.0
28
+ bn_wd: False
29
+ clip_grad: 20
30
+ lr: 0.0001
31
+ args: {
32
+ betas: !!python/tuple [0.9, 0.999], # for default Adam
33
+ }
34
+ scheduler: {
35
+ periods: [70, 10],
36
+ gamma: 0.1,
37
+ }
38
+
39
+ model:
40
+ name: 'modules.model_language.BCNLanguage'
41
+ language: {
42
+ num_layers: 4,
43
+ loss_weight: 1.,
44
+ use_self_attn: False
45
+ }
configs/pretrain_vision_model.yaml ADDED
@@ -0,0 +1,58 @@
1
+ global:
2
+ name: pretrain-vision-model
3
+ phase: train
4
+ stage: pretrain-vision
5
+ workdir: workdir
6
+ seed: ~
7
+
8
+ dataset:
9
+ train: {
10
+ roots: ['data/training/MJ/MJ_train/',
11
+ 'data/training/MJ/MJ_test/',
12
+ 'data/training/MJ/MJ_valid/',
13
+ 'data/training/ST'],
14
+ batch_size: 384
15
+ }
16
+ test: {
17
+ roots: ['data/evaluation/IIIT5k_3000',
18
+ 'data/evaluation/SVT',
19
+ 'data/evaluation/SVTP',
20
+ 'data/evaluation/IC13_857',
21
+ 'data/evaluation/IC15_1811',
22
+ 'data/evaluation/CUTE80'],
23
+ batch_size: 384
24
+ }
25
+ data_aug: True
26
+ multiscales: False
27
+ num_workers: 14
28
+
29
+ training:
30
+ epochs: 8
31
+ show_iters: 50
32
+ eval_iters: 3000
33
+ save_iters: 3000
34
+
35
+ optimizer:
36
+ type: Adam
37
+ true_wd: False
38
+ wd: 0.0
39
+ bn_wd: False
40
+ clip_grad: 20
41
+ lr: 0.0001
42
+ args: {
43
+ betas: !!python/tuple [0.9, 0.999], # for default Adam
44
+ }
45
+ scheduler: {
46
+ periods: [6, 2],
47
+ gamma: 0.1,
48
+ }
49
+
50
+ model:
51
+ name: 'modules.model_vision.BaseVision'
52
+ checkpoint: ~
53
+ vision: {
54
+ loss_weight: 1.,
55
+ attention: 'position',
56
+ backbone: 'transformer',
57
+ backbone_ln: 3,
58
+ }
configs/pretrain_vision_model_sv.yaml ADDED
@@ -0,0 +1,58 @@
1
+ global:
2
+ name: pretrain-vision-model-sv
3
+ phase: train
4
+ stage: pretrain-vision
5
+ workdir: workdir
6
+ seed: ~
7
+
8
+ dataset:
9
+ train: {
10
+ roots: ['data/training/MJ/MJ_train/',
11
+ 'data/training/MJ/MJ_test/',
12
+ 'data/training/MJ/MJ_valid/',
13
+ 'data/training/ST'],
14
+ batch_size: 384
15
+ }
16
+ test: {
17
+ roots: ['data/evaluation/IIIT5k_3000',
18
+ 'data/evaluation/SVT',
19
+ 'data/evaluation/SVTP',
20
+ 'data/evaluation/IC13_857',
21
+ 'data/evaluation/IC15_1811',
22
+ 'data/evaluation/CUTE80'],
23
+ batch_size: 384
24
+ }
25
+ data_aug: True
26
+ multiscales: False
27
+ num_workers: 14
28
+
29
+ training:
30
+ epochs: 8
31
+ show_iters: 50
32
+ eval_iters: 3000
33
+ save_iters: 3000
34
+
35
+ optimizer:
36
+ type: Adam
37
+ true_wd: False
38
+ wd: 0.0
39
+ bn_wd: False
40
+ clip_grad: 20
41
+ lr: 0.0001
42
+ args: {
43
+ betas: !!python/tuple [0.9, 0.999], # for default Adam
44
+ }
45
+ scheduler: {
46
+ periods: [6, 2],
47
+ gamma: 0.1,
48
+ }
49
+
50
+ model:
51
+ name: 'modules.model_vision.BaseVision'
52
+ checkpoint: ~
53
+ vision: {
54
+ loss_weight: 1.,
55
+ attention: 'attention',
56
+ backbone: 'transformer',
57
+ backbone_ln: 2,
58
+ }
configs/template.yaml ADDED
@@ -0,0 +1,67 @@
1
+ global:
2
+ name: exp
3
+ phase: train
4
+ stage: pretrain-vision
5
+ workdir: /tmp/workdir
6
+ seed: ~
7
+
8
+ dataset:
9
+ train: {
10
+ roots: ['data/training/MJ/MJ_train/',
11
+ 'data/training/MJ/MJ_test/',
12
+ 'data/training/MJ/MJ_valid/',
13
+ 'data/training/ST'],
14
+ batch_size: 128
15
+ }
16
+ test: {
17
+ roots: ['data/evaluation/IIIT5k_3000',
18
+ 'data/evaluation/SVT',
19
+ 'data/evaluation/SVTP',
20
+ 'data/evaluation/IC13_857',
21
+ 'data/evaluation/IC15_1811',
22
+ 'data/evaluation/CUTE80'],
23
+ batch_size: 128
24
+ }
25
+ charset_path: data/charset_36.txt
26
+ num_workers: 4
27
+ max_length: 25 # 30
28
+ image_height: 32
29
+ image_width: 128
30
+ case_sensitive: False
31
+ eval_case_sensitive: False
32
+ data_aug: True
33
+ multiscales: False
34
+ pin_memory: True
35
+ smooth_label: False
36
+ smooth_factor: 0.1
37
+ one_hot_y: True
38
+ use_sm: False
39
+
40
+ training:
41
+ epochs: 6
42
+ show_iters: 50
43
+ eval_iters: 3000
44
+ save_iters: 20000
45
+ start_iters: 0
46
+ stats_iters: 100000
47
+
48
+ optimizer:
49
+ type: Adadelta # Adadelta, Adam
50
+ true_wd: False
51
+ wd: 0. # 0.001
52
+ bn_wd: False
53
+ args: {
54
+ # betas: !!python/tuple [0.9, 0.99], # betas=(0.9,0.99) for AdamW
55
+ # betas: !!python/tuple [0.9, 0.999], # for default Adam
56
+ }
57
+ clip_grad: 20
58
+ lr: [1.0, 1.0, 1.0] # lr: [0.005, 0.005, 0.005]
59
+ scheduler: {
60
+ periods: [3, 2, 1],
61
+ gamma: 0.1,
62
+ }
63
+
64
+ model:
65
+ name: 'modules.model_abinet.ABINetModel'
66
+ checkpoint: ~
67
+ strict: True
configs/train_abinet.yaml ADDED
@@ -0,0 +1,71 @@
1
+ global:
2
+ name: train-abinet
3
+ phase: train
4
+ stage: train-super
5
+ workdir: workdir
6
+ seed: ~
7
+
8
+ dataset:
9
+ train: {
10
+ roots: ['data/training/MJ/MJ_train/',
11
+ 'data/training/MJ/MJ_test/',
12
+ 'data/training/MJ/MJ_valid/',
13
+ 'data/training/ST'],
14
+ batch_size: 384
15
+ }
16
+ test: {
17
+ roots: ['data/evaluation/IIIT5k_3000',
18
+ 'data/evaluation/SVT',
19
+ 'data/evaluation/SVTP',
20
+ 'data/evaluation/IC13_857',
21
+ 'data/evaluation/IC15_1811',
22
+ 'data/evaluation/CUTE80'],
23
+ batch_size: 384
24
+ }
25
+ data_aug: True
26
+ multiscales: False
27
+ num_workers: 14
28
+
29
+ training:
30
+ epochs: 10
31
+ show_iters: 50
32
+ eval_iters: 3000
33
+ save_iters: 3000
34
+
35
+ optimizer:
36
+ type: Adam
37
+ true_wd: False
38
+ wd: 0.0
39
+ bn_wd: False
40
+ clip_grad: 20
41
+ lr: 0.0001
42
+ args: {
43
+ betas: !!python/tuple [0.9, 0.999], # for default Adam
44
+ }
45
+ scheduler: {
46
+ periods: [6, 4],
47
+ gamma: 0.1,
48
+ }
49
+
50
+ model:
51
+ name: 'modules.model_abinet_iter.ABINetIterModel'
52
+ iter_size: 3
53
+ ensemble: ''
54
+ use_vision: False
55
+ vision: {
56
+ checkpoint: workdir/pretrain-vision-model/best-pretrain-vision-model.pth,
57
+ loss_weight: 1.,
58
+ attention: 'position',
59
+ backbone: 'transformer',
60
+ backbone_ln: 3,
61
+ }
62
+ language: {
63
+ checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth,
64
+ num_layers: 4,
65
+ loss_weight: 1.,
66
+ detach: True,
67
+ use_self_attn: False
68
+ }
69
+ alignment: {
70
+ loss_weight: 1.,
71
+ }
configs/train_abinet_sv.yaml ADDED
@@ -0,0 +1,71 @@
1
+ global:
2
+ name: train-abinet-sv
3
+ phase: train
4
+ stage: train-super
5
+ workdir: workdir
6
+ seed: ~
7
+
8
+ dataset:
9
+ train: {
10
+ roots: ['data/training/MJ/MJ_train/',
11
+ 'data/training/MJ/MJ_test/',
12
+ 'data/training/MJ/MJ_valid/',
13
+ 'data/training/ST'],
14
+ batch_size: 384
15
+ }
16
+ test: {
17
+ roots: ['data/evaluation/IIIT5k_3000',
18
+ 'data/evaluation/SVT',
19
+ 'data/evaluation/SVTP',
20
+ 'data/evaluation/IC13_857',
21
+ 'data/evaluation/IC15_1811',
22
+ 'data/evaluation/CUTE80'],
23
+ batch_size: 384
24
+ }
25
+ data_aug: True
26
+ multiscales: False
27
+ num_workers: 14
28
+
29
+ training:
30
+ epochs: 10
31
+ show_iters: 50
32
+ eval_iters: 3000
33
+ save_iters: 3000
34
+
35
+ optimizer:
36
+ type: Adam
37
+ true_wd: False
38
+ wd: 0.0
39
+ bn_wd: False
40
+ clip_grad: 20
41
+ lr: 0.0001
42
+ args: {
43
+ betas: !!python/tuple [0.9, 0.999], # for default Adam
44
+ }
45
+ scheduler: {
46
+ periods: [6, 4],
47
+ gamma: 0.1,
48
+ }
49
+
50
+ model:
51
+ name: 'modules.model_abinet_iter.ABINetIterModel'
52
+ iter_size: 3
53
+ ensemble: ''
54
+ use_vision: False
55
+ vision: {
56
+ checkpoint: workdir/pretrain-vision-model-sv/best-pretrain-vision-model-sv.pth,
57
+ loss_weight: 1.,
58
+ attention: 'attention',
59
+ backbone: 'transformer',
60
+ backbone_ln: 2,
61
+ }
62
+ language: {
63
+ checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth,
64
+ num_layers: 4,
65
+ loss_weight: 1.,
66
+ detach: True,
67
+ use_self_attn: False
68
+ }
69
+ alignment: {
70
+ loss_weight: 1.,
71
+ }
configs/train_abinet_wo_iter.yaml ADDED
@@ -0,0 +1,68 @@
1
+ global:
2
+ name: train-abinet-wo-iter
3
+ phase: train
4
+ stage: train-super
5
+ workdir: workdir
6
+ seed: ~
7
+
8
+ dataset:
9
+ train: {
10
+ roots: ['data/training/MJ/MJ_train/',
11
+ 'data/training/MJ/MJ_test/',
12
+ 'data/training/MJ/MJ_valid/',
13
+ 'data/training/ST'],
14
+ batch_size: 384
15
+ }
16
+ test: {
17
+ roots: ['data/evaluation/IIIT5k_3000',
18
+ 'data/evaluation/SVT',
19
+ 'data/evaluation/SVTP',
20
+ 'data/evaluation/IC13_857',
21
+ 'data/evaluation/IC15_1811',
22
+ 'data/evaluation/CUTE80'],
23
+ batch_size: 384
24
+ }
25
+ data_aug: True
26
+ multiscales: False
27
+ num_workers: 14
28
+
29
+ training:
30
+ epochs: 10
31
+ show_iters: 50
32
+ eval_iters: 3000
33
+ save_iters: 3000
34
+
35
+ optimizer:
36
+ type: Adam
37
+ true_wd: False
38
+ wd: 0.0
39
+ bn_wd: False
40
+ clip_grad: 20
41
+ lr: 0.0001
42
+ args: {
43
+ betas: !!python/tuple [0.9, 0.999], # for default Adam
44
+ }
45
+ scheduler: {
46
+ periods: [6, 4],
47
+ gamma: 0.1,
48
+ }
49
+
50
+ model:
51
+ name: 'modules.model_abinet.ABINetModel'
52
+ vision: {
53
+ checkpoint: workdir/pretrain-vision-model/best-pretrain-vision-model.pth,
54
+ loss_weight: 1.,
55
+ attention: 'position',
56
+ backbone: 'transformer',
57
+ backbone_ln: 3,
58
+ }
59
+ language: {
60
+ checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth,
61
+ num_layers: 4,
62
+ loss_weight: 1.,
63
+ detach: True,
64
+ use_self_attn: False
65
+ }
66
+ alignment: {
67
+ loss_weight: 1.,
68
+ }
data/charset_36.txt ADDED
@@ -0,0 +1,36 @@
1
+ 0 a
2
+ 1 b
3
+ 2 c
4
+ 3 d
5
+ 4 e
6
+ 5 f
7
+ 6 g
8
+ 7 h
9
+ 8 i
10
+ 9 j
11
+ 10 k
12
+ 11 l
13
+ 12 m
14
+ 13 n
15
+ 14 o
16
+ 15 p
17
+ 16 q
18
+ 17 r
19
+ 18 s
20
+ 19 t
21
+ 20 u
22
+ 21 v
23
+ 22 w
24
+ 23 x
25
+ 24 y
26
+ 25 z
27
+ 26 1
28
+ 27 2
29
+ 28 3
30
+ 29 4
31
+ 30 5
32
+ 31 6
33
+ 32 7
34
+ 33 8
35
+ 34 9
36
+ 35 0
data/charset_62.txt ADDED
@@ -0,0 +1,62 @@
1
+ 0 0
2
+ 1 1
3
+ 2 2
4
+ 3 3
5
+ 4 4
6
+ 5 5
7
+ 6 6
8
+ 7 7
9
+ 8 8
10
+ 9 9
11
+ 10 A
12
+ 11 B
13
+ 12 C
14
+ 13 D
15
+ 14 E
16
+ 15 F
17
+ 16 G
18
+ 17 H
19
+ 18 I
20
+ 19 J
21
+ 20 K
22
+ 21 L
23
+ 22 M
24
+ 23 N
25
+ 24 O
26
+ 25 P
27
+ 26 Q
28
+ 27 R
29
+ 28 S
30
+ 29 T
31
+ 30 U
32
+ 31 V
33
+ 32 W
34
+ 33 X
35
+ 34 Y
36
+ 35 Z
37
+ 36 a
38
+ 37 b
39
+ 38 c
40
+ 39 d
41
+ 40 e
42
+ 41 f
43
+ 42 g
44
+ 43 h
45
+ 44 i
46
+ 45 j
47
+ 46 k
48
+ 47 l
49
+ 48 m
50
+ 49 n
51
+ 50 o
52
+ 51 p
53
+ 52 q
54
+ 53 r
55
+ 54 s
56
+ 55 t
57
+ 56 u
58
+ 57 v
59
+ 58 w
60
+ 59 x
61
+ 60 y
62
+ 61 z
dataset.py ADDED
@@ -0,0 +1,278 @@
1
+ import logging
2
+ import re
3
+
4
+ import cv2
5
+ import lmdb
6
+ import six
7
+ from fastai.vision import *
8
+ from torchvision import transforms
9
+
10
+ from transforms import CVColorJitter, CVDeterioration, CVGeometry
11
+ from utils import CharsetMapper, onehot
12
+
13
+
14
+ class ImageDataset(Dataset):
15
+ "`ImageDataset` reads data from an LMDB database."
16
+
17
+ def __init__(self,
18
+ path:PathOrStr,
19
+ is_training:bool=True,
20
+ img_h:int=32,
21
+ img_w:int=100,
22
+ max_length:int=25,
23
+ check_length:bool=True,
24
+ case_sensitive:bool=False,
25
+ charset_path:str='data/charset_36.txt',
26
+ convert_mode:str='RGB',
27
+ data_aug:bool=True,
28
+ deteriorate_ratio:float=0.,
29
+ multiscales:bool=True,
30
+ one_hot_y:bool=True,
31
+ return_idx:bool=False,
32
+ return_raw:bool=False,
33
+ **kwargs):
34
+ self.path, self.name = Path(path), Path(path).name
35
+ assert self.path.is_dir() and self.path.exists(), f"{path} is not a valid directory."
36
+ self.convert_mode, self.check_length = convert_mode, check_length
37
+ self.img_h, self.img_w = img_h, img_w
38
+ self.max_length, self.one_hot_y = max_length, one_hot_y
39
+ self.return_idx, self.return_raw = return_idx, return_raw
40
+ self.case_sensitive, self.is_training = case_sensitive, is_training
41
+ self.data_aug, self.multiscales = data_aug, multiscales
42
+ self.charset = CharsetMapper(charset_path, max_length=max_length+1)
43
+ self.c = self.charset.num_classes
44
+
45
+ self.env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False)
46
+ assert self.env, f'Cannot open LMDB dataset from {path}.'
47
+ with self.env.begin(write=False) as txn:
48
+ self.length = int(txn.get('num-samples'.encode()))
49
+
50
+ if self.is_training and self.data_aug:
51
+ self.augment_tfs = transforms.Compose([
52
+ CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5),
53
+ CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
54
+ CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25)
55
+ ])
56
+ self.totensor = transforms.ToTensor()
57
+
58
+ def __len__(self): return self.length
59
+
60
+ def _next_image(self, index):
61
+ next_index = random.randint(0, len(self) - 1)
62
+ return self.get(next_index)
63
+
64
+ def _check_image(self, x, pixels=6):
65
+ if x.size[0] <= pixels or x.size[1] <= pixels: return False
66
+ else: return True
67
+
68
+ def resize_multiscales(self, img, borderType=cv2.BORDER_CONSTANT):
69
+ def _resize_ratio(img, ratio, fix_h=True):
70
+ if ratio * self.img_w < self.img_h:
71
+ if fix_h: trg_h = self.img_h
72
+ else: trg_h = int(ratio * self.img_w)
73
+ trg_w = self.img_w
74
+ else: trg_h, trg_w = self.img_h, int(self.img_h / ratio)
75
+ img = cv2.resize(img, (trg_w, trg_h))
76
+ pad_h, pad_w = (self.img_h - trg_h) / 2, (self.img_w - trg_w) / 2
77
+ top, bottom = math.ceil(pad_h), math.floor(pad_h)
78
+ left, right = math.ceil(pad_w), math.floor(pad_w)
79
+ img = cv2.copyMakeBorder(img, top, bottom, left, right, borderType)
80
+ return img
81
+
82
+ if self.is_training:
83
+ if random.random() < 0.5:
84
+ base, maxh, maxw = self.img_h, self.img_h, self.img_w
85
+ h, w = random.randint(base, maxh), random.randint(base, maxw)
86
+ return _resize_ratio(img, h/w)
87
+ else: return _resize_ratio(img, img.shape[0] / img.shape[1]) # keep aspect ratio
88
+ else: return _resize_ratio(img, img.shape[0] / img.shape[1]) # keep aspect ratio
89
+
90
+ def resize(self, img):
91
+ if self.multiscales: return self.resize_multiscales(img, cv2.BORDER_REPLICATE)
92
+ else: return cv2.resize(img, (self.img_w, self.img_h))
93
+
94
+ def get(self, idx):
95
+ with self.env.begin(write=False) as txn:
96
+ image_key, label_key = f'image-{idx+1:09d}', f'label-{idx+1:09d}'
97
+ try:
98
+ label = str(txn.get(label_key.encode()), 'utf-8') # label
99
+ label = re.sub('[^0-9a-zA-Z]+', '', label)
100
+ if self.check_length and self.max_length > 0:
101
+ if len(label) > self.max_length or len(label) <= 0:
102
+ #logging.info(f'Long or short text image is found: {self.name}, {idx}, {label}, {len(label)}')
103
+ return self._next_image(idx)
104
+ label = label[:self.max_length]
105
+
106
+ imgbuf = txn.get(image_key.encode()) # image
107
+ buf = six.BytesIO()
108
+ buf.write(imgbuf)
109
+ buf.seek(0)
110
+ with warnings.catch_warnings():
111
+ warnings.simplefilter("ignore", UserWarning) # EXIF warning from TiffPlugin
112
+ image = PIL.Image.open(buf).convert(self.convert_mode)
113
+ if self.is_training and not self._check_image(image):
114
+ #logging.info(f'Invalid image is found: {self.name}, {idx}, {label}, {len(label)}')
115
+ return self._next_image(idx)
116
+ except:
117
+ import traceback
118
+ traceback.print_exc()
119
+ logging.info(f'Corrupted image is found: {self.name}, {idx}, {label}, {len(label)}')
120
+ return self._next_image(idx)
121
+ return image, label, idx
122
+
123
+ def _process_training(self, image):
124
+ if self.data_aug: image = self.augment_tfs(image)
125
+ image = self.resize(np.array(image))
126
+ return image
127
+
128
+ def _process_test(self, image):
129
+ return self.resize(np.array(image)) # TODO:move is_training to here
130
+
131
+ def __getitem__(self, idx):
132
+ image, text, idx_new = self.get(idx)
133
+ if not self.is_training: assert idx == idx_new, f'idx {idx} != idx_new {idx_new} during testing.'
134
+
135
+ if self.is_training: image = self._process_training(image)
136
+ else: image = self._process_test(image)
137
+ if self.return_raw: return image, text
138
+ image = self.totensor(image)
139
+
140
+ length = tensor(len(text) + 1).to(dtype=torch.long) # one for end token
141
+ label = self.charset.get_labels(text, case_sensitive=self.case_sensitive)
142
+ label = tensor(label).to(dtype=torch.long)
143
+ if self.one_hot_y: label = onehot(label, self.charset.num_classes)
144
+
145
+ if self.return_idx: y = [label, length, idx_new]
146
+ else: y = [label, length]
147
+ return image, y
148
+
149
+
150
+ class TextDataset(Dataset):
151
+ def __init__(self,
152
+ path:PathOrStr,
153
+ delimiter:str='\t',
154
+ max_length:int=25,
155
+ charset_path:str='data/charset_36.txt',
156
+ case_sensitive=False,
157
+ one_hot_x=True,
158
+ one_hot_y=True,
159
+ is_training=True,
160
+ smooth_label=False,
161
+ smooth_factor=0.2,
162
+ use_sm=False,
163
+ **kwargs):
164
+ self.path = Path(path)
165
+ self.case_sensitive, self.use_sm = case_sensitive, use_sm
166
+ self.smooth_factor, self.smooth_label = smooth_factor, smooth_label
167
+ self.charset = CharsetMapper(charset_path, max_length=max_length+1)
168
+ self.one_hot_x, self.one_hot_y, self.is_training = one_hot_x, one_hot_y, is_training
169
+ if self.is_training and self.use_sm: self.sm = SpellingMutation(charset=self.charset)
170
+
171
+ dtype = {'inp': str, 'gt': str}
172
+ self.df = pd.read_csv(self.path, dtype=dtype, delimiter=delimiter, na_filter=False)
173
+ self.inp_col, self.gt_col = 0, 1
174
+
175
+ def __len__(self): return len(self.df)
176
+
177
+ def __getitem__(self, idx):
178
+ text_x = self.df.iloc[idx, self.inp_col]
179
+ text_x = re.sub('[^0-9a-zA-Z]+', '', text_x)
180
+ if not self.case_sensitive: text_x = text_x.lower()
181
+ if self.is_training and self.use_sm: text_x = self.sm(text_x)
182
+
183
+ length_x = tensor(len(text_x) + 1).to(dtype=torch.long) # one for end token
184
+ label_x = self.charset.get_labels(text_x, case_sensitive=self.case_sensitive)
185
+ label_x = tensor(label_x)
186
+ if self.one_hot_x:
187
+ label_x = onehot(label_x, self.charset.num_classes)
188
+ if self.is_training and self.smooth_label:
189
+ label_x = torch.stack([self.prob_smooth_label(l) for l in label_x])
190
+ x = [label_x, length_x]
191
+
192
+ text_y = self.df.iloc[idx, self.gt_col]
193
+ text_y = re.sub('[^0-9a-zA-Z]+', '', text_y)
194
+ if not self.case_sensitive: text_y = text_y.lower()
195
+ length_y = tensor(len(text_y) + 1).to(dtype=torch.long) # one for end token
196
+ label_y = self.charset.get_labels(text_y, case_sensitive=self.case_sensitive)
197
+ label_y = tensor(label_y)
198
+ if self.one_hot_y: label_y = onehot(label_y, self.charset.num_classes)
199
+ y = [label_y, length_y]
200
+
201
+ return x, y
202
+
203
+ def prob_smooth_label(self, one_hot):
204
+ one_hot = one_hot.float()
205
+ delta = torch.rand([]) * self.smooth_factor
206
+ num_classes = len(one_hot)
207
+ noise = torch.rand(num_classes)
208
+ noise = noise / noise.sum() * delta
209
+ one_hot = one_hot * (1 - delta) + noise
210
+ return one_hot
211
+
212
+
213
+ class SpellingMutation(object):
214
+ def __init__(self, pn0=0.7, pn1=0.85, pn2=0.95, pt0=0.7, pt1=0.85, charset=None):
215
+ """
216
+ Args:
217
+ pn0: the prob of not modifying any character is pn0
+ pn1: the prob of modifying one character is (pn1 - pn0)
+ pn2: the prob of modifying two characters is (pn2 - pn1),
+ and of modifying three is (1 - pn2)
+ pt0: the prob of the replacing operation is pt0,
+ pt1: the prob of the inserting operation is (pt1 - pt0),
+ and of the deleting operation is (1 - pt1)
224
+ """
225
+ super().__init__()
226
+ self.pn0, self.pn1, self.pn2 = pn0, pn1, pn2
227
+ self.pt0, self.pt1 = pt0, pt1
228
+ self.charset = charset
229
+ logging.info(f'the probs: pn0={self.pn0}, pn1={self.pn1} ' +
230
+ f'pn2={self.pn2}, pt0={self.pt0}, pt1={self.pt1}')
231
+
232
+ def is_digit(self, text, ratio=0.5):
233
+ length = max(len(text), 1)
234
+ digit_num = sum([t in self.charset.digits for t in text])
235
+ if digit_num / length < ratio: return False
236
+ return True
237
+
238
+ def is_unk_char(self, char):
239
+ # return char == self.charset.unk_char
240
+ return (char not in self.charset.digits) and (char not in self.charset.alphabets)
241
+
242
+ def get_num_to_modify(self, length):
243
+ prob = random.random()
244
+ if prob < self.pn0: num_to_modify = 0
245
+ elif prob < self.pn1: num_to_modify = 1
246
+ elif prob < self.pn2: num_to_modify = 2
247
+ else: num_to_modify = 3
248
+
249
+ if length <= 1: num_to_modify = 0
250
+ elif length >= 2 and length <= 4: num_to_modify = min(num_to_modify, 1)
251
+ else: num_to_modify = min(num_to_modify, length // 2) # smaller than length // 2
252
+ return num_to_modify
253
+
254
+ def __call__(self, text, debug=False):
255
+ if self.is_digit(text): return text
256
+ length = len(text)
257
+ num_to_modify = self.get_num_to_modify(length)
258
+ if num_to_modify <= 0: return text
259
+
260
+ chars = []
261
+ index = np.arange(0, length)
262
+ random.shuffle(index)
263
+ index = index[: num_to_modify]
264
+ if debug: self.index = index
265
+ for i, t in enumerate(text):
266
+ if i not in index: chars.append(t)
267
+ elif self.is_unk_char(t): chars.append(t)
268
+ else:
269
+ prob = random.random()
270
+ if prob < self.pt0: # replace
271
+ chars.append(random.choice(self.charset.alphabets))
272
+ elif prob < self.pt1: # insert
273
+ chars.append(random.choice(self.charset.alphabets))
274
+ chars.append(t)
275
+ else: # delete
276
+ continue
277
+ new_text = ''.join(chars[: self.charset.max_length-1])
278
+ return new_text if len(new_text) >= 1 else text
demo.py ADDED
@@ -0,0 +1,109 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import glob
5
+ import tqdm
6
+ import torch
7
+ import PIL
8
+ import cv2
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+ from torchvision import transforms
12
+ from utils import Config, Logger, CharsetMapper
13
+
14
+ def get_model(config):
15
+ import importlib
16
+ names = config.model_name.split('.')
17
+ module_name, class_name = '.'.join(names[:-1]), names[-1]
18
+ cls = getattr(importlib.import_module(module_name), class_name)
19
+ model = cls(config)
20
+ logging.info(model)
21
+ model = model.eval()
22
+ return model
23
+
24
+ def preprocess(img, width, height):
25
+ img = cv2.resize(np.array(img), (width, height))
26
+ img = transforms.ToTensor()(img).unsqueeze(0)
27
+ mean = torch.tensor([0.485, 0.456, 0.406])
28
+ std = torch.tensor([0.229, 0.224, 0.225])
29
+ return (img-mean[...,None,None]) / std[...,None,None]
30
+
31
+ def postprocess(output, charset, model_eval):
32
+ def _get_output(last_output, model_eval):
33
+ if isinstance(last_output, (tuple, list)):
34
+ for res in last_output:
35
+ if res['name'] == model_eval: output = res
36
+ else: output = last_output
37
+ return output
38
+
39
+ def _decode(logit):
40
+ """ Greed decode """
41
+ out = F.softmax(logit, dim=2)
42
+ pt_text, pt_scores, pt_lengths = [], [], []
43
+ for o in out:
44
+ text = charset.get_text(o.argmax(dim=1), padding=False, trim=False)
45
+ text = text.split(charset.null_char)[0] # end at end-token
46
+ pt_text.append(text)
47
+ pt_scores.append(o.max(dim=1)[0])
48
+ pt_lengths.append(min(len(text) + 1, charset.max_length)) # one for end-token
49
+ return pt_text, pt_scores, pt_lengths
50
+
51
+ output = _get_output(output, model_eval)
52
+ logits, pt_lengths = output['logits'], output['pt_lengths']
53
+ pt_text, pt_scores, pt_lengths_ = _decode(logits)
54
+
55
+ return pt_text, pt_scores, pt_lengths_
56
+
57
+ def load(model, file, device=None, strict=True):
58
+ if device is None: device = 'cpu'
59
+ elif isinstance(device, int): device = torch.device('cuda', device)
60
+ assert os.path.isfile(file)
61
+ state = torch.load(file, map_location=device)
62
+ if set(state.keys()) == {'model', 'opt'}:
63
+ state = state['model']
64
+ model.load_state_dict(state, strict=strict)
65
+ return model
66
+
67
+ def main():
68
+ parser = argparse.ArgumentParser()
69
+ parser.add_argument('--config', type=str, default='configs/train_abinet.yaml',
70
+ help='path to config file')
71
+ parser.add_argument('--input', type=str, default='figs/test')
72
+ parser.add_argument('--cuda', type=int, default=-1)
73
+ parser.add_argument('--checkpoint', type=str, default='workdir/train-abinet/best-train-abinet.pth')
74
+ parser.add_argument('--model_eval', type=str, default='alignment',
75
+ choices=['alignment', 'vision', 'language'])
76
+ args = parser.parse_args()
77
+ config = Config(args.config)
78
+ if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
79
+ if args.model_eval is not None: config.model_eval = args.model_eval
80
+ config.global_phase = 'test'
81
+ config.model_vision_checkpoint, config.model_language_checkpoint = None, None
82
+ device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'
83
+
84
+ Logger.init(config.global_workdir, config.global_name, config.global_phase)
85
+ Logger.enable_file()
86
+ logging.info(config)
87
+
88
+ logging.info('Construct model.')
89
+ model = get_model(config).to(device)
90
+ model = load(model, config.model_checkpoint, device=device)
91
+ charset = CharsetMapper(filename=config.dataset_charset_path,
92
+ max_length=config.dataset_max_length + 1)
93
+
94
+ if os.path.isdir(args.input):
95
+ paths = [os.path.join(args.input, fname) for fname in os.listdir(args.input)]
96
+ else:
97
+ paths = glob.glob(os.path.expanduser(args.input))
98
+ assert paths, "The input path(s) was not found"
99
+ paths = sorted(paths)
100
+ for path in tqdm.tqdm(paths):
101
+ img = PIL.Image.open(path).convert('RGB')
102
+ img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
103
+ img = img.to(device)
104
+ res = model(img)
105
+ pt_text, _, __ = postprocess(res, charset, config.model_eval)
106
+ logging.info(f'{path}: {pt_text[0]}')
107
+
108
+ if __name__ == '__main__':
109
+ main()
docker/Dockerfile ADDED
@@ -0,0 +1,25 @@
1
+ FROM anibali/pytorch:cuda-9.0
2
+ MAINTAINER fangshancheng <fangsc@ustc.edu.cn>
3
+ RUN sudo rm -rf /etc/apt/sources.list.d && \
4
+ sudo apt update && \
5
+ sudo apt install -y build-essential vim && \
6
+ conda config --add channels https://mirrors.ustc.edu.cn/anaconda/pkgs/free/ && \
7
+ conda config --add channels https://mirrors.ustc.edu.cn/anaconda/pkgs/main/ && \
8
+ conda config --set show_channel_urls yes && \
9
+ pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ && \
10
+ pip install torch==1.1.0 torchvision==0.3.0 && \
11
+ pip install fastai==1.0.60 && \
12
+ pip install ipdb jupyter ipython lmdb editdistance tensorboardX natsort nltk && \
13
+ conda uninstall -y --force pillow pil jpeg libtiff libjpeg-turbo && \
14
+ pip uninstall -y pillow pil jpeg libtiff libjpeg-turbo && \
15
+ conda install -yc conda-forge libjpeg-turbo && \
16
+ CFLAGS="${CFLAGS} -mavx2" pip install --no-cache-dir --force-reinstall --no-binary :all: --compile pillow-simd==6.2.2.post1 && \
17
+ conda install -y jpeg libtiff opencv && \
18
+ sudo rm -rf /var/lib/apt/lists/* && \
19
+ sudo rm -rf /tmp/* && \
20
+ sudo rm -rf ~/.cache && \
21
+ sudo apt clean all && \
22
+ conda clean -y -a
23
+ EXPOSE 8888
24
+ ENV LANG C.UTF-8
25
+ ENV LC_ALL C.UTF-8
figs/cases.png ADDED
figs/framework.png ADDED
figs/test/CANDY.png ADDED
figs/test/ESPLANADE.png ADDED
figs/test/GLOBE.png ADDED
figs/test/KAPPA.png ADDED
figs/test/MANDARIN.png ADDED
figs/test/MEETS.png ADDED
figs/test/MONTHLY.png ADDED
figs/test/RESTROOM.png ADDED
losses.py ADDED
@@ -0,0 +1,72 @@
1
+ from fastai.vision import *
2
+
3
+ from modules.model import Model
4
+
5
+
6
+ class MultiLosses(nn.Module):
7
+ def __init__(self, one_hot=True):
8
+ super().__init__()
9
+ self.ce = SoftCrossEntropyLoss() if one_hot else torch.nn.CrossEntropyLoss()
10
+ self.bce = torch.nn.BCELoss()
11
+
12
+ @property
13
+ def last_losses(self):
14
+ return self.losses
15
+
16
+ def _flatten(self, sources, lengths):
17
+ return torch.cat([t[:l] for t, l in zip(sources, lengths)])
18
+
19
+ def _merge_list(self, all_res):
20
+ if not isinstance(all_res, (list, tuple)):
21
+ return all_res
22
+ def merge(items):
23
+ if isinstance(items[0], torch.Tensor): return torch.cat(items, dim=0)
24
+ else: return items[0]
25
+ res = dict()
26
+ for key in all_res[0].keys():
27
+ items = [r[key] for r in all_res]
28
+ res[key] = merge(items)
29
+ return res
30
+
31
+ def _ce_loss(self, output, gt_labels, gt_lengths, idx=None, record=True):
32
+ loss_name = output.get('name')
33
+ pt_logits, weight = output['logits'], output['loss_weight']
34
+
35
+ assert pt_logits.shape[0] % gt_labels.shape[0] == 0
36
+ iter_size = pt_logits.shape[0] // gt_labels.shape[0]
37
+ if iter_size > 1:
38
+ gt_labels = gt_labels.repeat(iter_size, 1, 1)  # repeat labels once per refinement iteration
39
+ gt_lengths = gt_lengths.repeat(iter_size)
40
+ flat_gt_labels = self._flatten(gt_labels, gt_lengths)
41
+ flat_pt_logits = self._flatten(pt_logits, gt_lengths)
42
+
43
+ nll = output.get('nll')
44
+ if nll is not None:
45
+ loss = self.ce(flat_pt_logits, flat_gt_labels, softmax=False) * weight
46
+ else:
47
+ loss = self.ce(flat_pt_logits, flat_gt_labels) * weight
48
+ if record and loss_name is not None: self.losses[f'{loss_name}_loss'] = loss
49
+
50
+ return loss
51
+
52
+ def forward(self, outputs, *args):
53
+ self.losses = {}
54
+ if isinstance(outputs, (tuple, list)):
55
+ outputs = [self._merge_list(o) for o in outputs]
56
+ return sum([self._ce_loss(o, *args) for o in outputs if o['loss_weight'] > 0.])
57
+ else:
58
+ return self._ce_loss(outputs, *args, record=False)
59
+
60
+
61
+ class SoftCrossEntropyLoss(nn.Module):
62
+ def __init__(self, reduction="mean"):
63
+ super().__init__()
64
+ self.reduction = reduction
65
+
66
+ def forward(self, input, target, softmax=True):
67
+ if softmax: log_prob = F.log_softmax(input, dim=-1)
68
+ else: log_prob = torch.log(input)
69
+ loss = -(target * log_prob).sum(dim=-1)
70
+ if self.reduction == "mean": return loss.mean()
71
+ elif self.reduction == "sum": return loss.sum()
72
+ else: return loss
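Note: with hard (one-hot) labels the soft cross-entropy above reduces exactly to the standard cross-entropy, which is why `one_hot=True` and `one_hot=False` are interchangeable unless label smoothing is enabled. A self-contained sketch in plain PyTorch (not part of this file) that checks the equivalence:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 37)                      # (batch, num_classes); 37 is an illustrative charset size
labels = torch.randint(0, 37, (4,))              # hard character labels
one_hot = F.one_hot(labels, num_classes=37).float()

# SoftCrossEntropyLoss(reduction="mean") computed by hand
soft_ce = -(one_hot * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
hard_ce = F.cross_entropy(logits, labels)        # what torch.nn.CrossEntropyLoss would return

assert torch.allclose(soft_ce, hard_ce, atol=1e-6)
```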
main.py ADDED
@@ -0,0 +1,246 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import random
5
+
6
+ import torch
7
+ from fastai.callbacks.general_sched import GeneralScheduler, TrainingPhase
8
+ from fastai.distributed import *
9
+ from fastai.vision import *
10
+ from torch.backends import cudnn
11
+
12
+ from callbacks import DumpPrediction, IterationCallback, TextAccuracy, TopKTextAccuracy
13
+ from dataset import ImageDataset, TextDataset
14
+ from losses import MultiLosses
15
+ from utils import Config, Logger, MyDataParallel, MyConcatDataset
16
+
17
+
18
+ def _set_random_seed(seed):
19
+ if seed is not None:
20
+ random.seed(seed)
21
+ torch.manual_seed(seed)
22
+ cudnn.deterministic = True
23
+ logging.warning('You have chosen to seed training. '
24
+ 'This will slow down your training!')
25
+
26
+ def _get_training_phases(config, n):
27
+ lr = np.array(config.optimizer_lr)
28
+ periods = config.optimizer_scheduler_periods
29
+ sigma = [config.optimizer_scheduler_gamma ** i for i in range(len(periods))]
30
+ phases = [TrainingPhase(n * periods[i]).schedule_hp('lr', lr * sigma[i])
31
+ for i in range(len(periods))]
32
+ return phases
33
+
34
+ def _get_dataset(ds_type, paths, is_training, config, **kwargs):
35
+ kwargs.update({
36
+ 'img_h': config.dataset_image_height,
37
+ 'img_w': config.dataset_image_width,
38
+ 'max_length': config.dataset_max_length,
39
+ 'case_sensitive': config.dataset_case_sensitive,
40
+ 'charset_path': config.dataset_charset_path,
41
+ 'data_aug': config.dataset_data_aug,
42
+ 'deteriorate_ratio': config.dataset_deteriorate_ratio,
43
+ 'is_training': is_training,
44
+ 'multiscales': config.dataset_multiscales,
45
+ 'one_hot_y': config.dataset_one_hot_y,
46
+ })
47
+ datasets = [ds_type(p, **kwargs) for p in paths]
48
+ if len(datasets) > 1: return MyConcatDataset(datasets)
49
+ else: return datasets[0]
50
+
51
+
52
+ def _get_language_databaunch(config):
53
+ kwargs = {
54
+ 'max_length': config.dataset_max_length,
55
+ 'case_sensitive': config.dataset_case_sensitive,
56
+ 'charset_path': config.dataset_charset_path,
57
+ 'smooth_label': config.dataset_smooth_label,
58
+ 'smooth_factor': config.dataset_smooth_factor,
59
+ 'one_hot_y': config.dataset_one_hot_y,
60
+ 'use_sm': config.dataset_use_sm,
61
+ }
62
+ train_ds = TextDataset(config.dataset_train_roots[0], is_training=True, **kwargs)
63
+ valid_ds = TextDataset(config.dataset_test_roots[0], is_training=False, **kwargs)
64
+ data = DataBunch.create(
65
+ path=train_ds.path,
66
+ train_ds=train_ds,
67
+ valid_ds=valid_ds,
68
+ bs=config.dataset_train_batch_size,
69
+ val_bs=config.dataset_test_batch_size,
70
+ num_workers=config.dataset_num_workers,
71
+ pin_memory=config.dataset_pin_memory)
72
+ logging.info(f'{len(data.train_ds)} training items found.')
73
+ if not data.empty_val:
74
+ logging.info(f'{len(data.valid_ds)} valid items found.')
75
+ return data
76
+
77
+ def _get_databaunch(config):
78
+ # An awkward way to reduce data loading time during testing
79
+ if config.global_phase == 'test': config.dataset_train_roots = config.dataset_test_roots
80
+ train_ds = _get_dataset(ImageDataset, config.dataset_train_roots, True, config)
81
+ valid_ds = _get_dataset(ImageDataset, config.dataset_test_roots, False, config)
82
+ data = ImageDataBunch.create(
83
+ train_ds=train_ds,
84
+ valid_ds=valid_ds,
85
+ bs=config.dataset_train_batch_size,
86
+ val_bs=config.dataset_test_batch_size,
87
+ num_workers=config.dataset_num_workers,
88
+ pin_memory=config.dataset_pin_memory).normalize(imagenet_stats)
89
+ ar_tfm = lambda x: ((x[0], x[1]), x[1]) # auto-regression only for dtd
90
+ data.add_tfm(ar_tfm)
91
+
92
+ logging.info(f'{len(data.train_ds)} training items found.')
93
+ if not data.empty_val:
94
+ logging.info(f'{len(data.valid_ds)} valid items found.')
95
+
96
+ return data
97
+
98
+ def _get_model(config):
99
+ import importlib
100
+ names = config.model_name.split('.')
101
+ module_name, class_name = '.'.join(names[:-1]), names[-1]
102
+ cls = getattr(importlib.import_module(module_name), class_name)
103
+ model = cls(config)
104
+ logging.info(model)
105
+ return model
106
+
107
+
108
+ def _get_learner(config, data, model, local_rank=None):
109
+ strict = ifnone(config.model_strict, True)
110
+ if config.global_stage == 'pretrain-language':
111
+ metrics = [TopKTextAccuracy(
112
+ k=ifnone(config.model_k, 5),
113
+ charset_path=config.dataset_charset_path,
114
+ max_length=config.dataset_max_length + 1,
115
+ case_sensitive=config.dataset_eval_case_sensisitves,
116
+ model_eval=config.model_eval)]
117
+ else:
118
+ metrics = [TextAccuracy(
119
+ charset_path=config.dataset_charset_path,
120
+ max_length=config.dataset_max_length + 1,
121
+ case_sensitive=config.dataset_eval_case_sensisitves,
122
+ model_eval=config.model_eval)]
123
+ opt_type = getattr(torch.optim, config.optimizer_type)
124
+ learner = Learner(data, model, silent=True, model_dir='.',
125
+ true_wd=config.optimizer_true_wd,
126
+ wd=config.optimizer_wd,
127
+ bn_wd=config.optimizer_bn_wd,
128
+ path=config.global_workdir,
129
+ metrics=metrics,
130
+ opt_func=partial(opt_type, **config.optimizer_args or dict()),
131
+ loss_func=MultiLosses(one_hot=config.dataset_one_hot_y))
132
+ learner.split(lambda m: children(m))
133
+
134
+ if config.global_phase == 'train':
135
+ num_replicas = 1 if local_rank is None else torch.distributed.get_world_size()
136
+ phases = _get_training_phases(config, len(learner.data.train_dl)//num_replicas)
137
+ learner.callback_fns += [
138
+ partial(GeneralScheduler, phases=phases),
139
+ partial(GradientClipping, clip=config.optimizer_clip_grad),
140
+ partial(IterationCallback, name=config.global_name,
141
+ show_iters=config.training_show_iters,
142
+ eval_iters=config.training_eval_iters,
143
+ save_iters=config.training_save_iters,
144
+ start_iters=config.training_start_iters,
145
+ stats_iters=config.training_stats_iters)]
146
+ else:
147
+ learner.callbacks += [
148
+ DumpPrediction(learn=learner,
149
+ dataset='-'.join([Path(p).name for p in config.dataset_test_roots]), charset_path=config.dataset_charset_path,
150
+ model_eval=config.model_eval,
151
+ debug=config.global_debug,
152
+ image_only=config.global_image_only)]
153
+
154
+ learner.rank = local_rank
155
+ if local_rank is not None:
156
+ logging.info(f'Set model to distributed with rank {local_rank}.')
157
+ learner.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(learner.model)
158
+ learner.model.to(local_rank)
159
+ learner = learner.to_distributed(local_rank)
160
+
161
+ if torch.cuda.device_count() > 1 and local_rank is None:
162
+ logging.info(f'Use {torch.cuda.device_count()} GPUs.')
163
+ learner.model = MyDataParallel(learner.model)
164
+
165
+ if config.model_checkpoint:
166
+ if Path(config.model_checkpoint).exists():
167
+ with open(config.model_checkpoint, 'rb') as f:
168
+ buffer = io.BytesIO(f.read())
169
+ learner.load(buffer, strict=strict)
170
+ else:
171
+ from distutils.dir_util import copy_tree
172
+ src = Path('/data/fangsc/model')/config.global_name
173
+ trg = Path('/output')/config.global_name
174
+ if src.exists(): copy_tree(str(src), str(trg))
175
+ learner.load(config.model_checkpoint, strict=strict)
176
+ logging.info(f'Read model from {config.model_checkpoint}')
177
+ elif config.global_phase == 'test':
178
+ learner.load(f'best-{config.global_name}', strict=strict)
179
+ logging.info(f'Read model from best-{config.global_name}')
180
+
181
+ if learner.opt_func.func.__name__ == 'Adadelta': # fastai bug, fix after 1.0.60
182
+ learner.fit(epochs=0, lr=config.optimizer_lr)
183
+ learner.opt.mom = 0.
184
+
185
+ return learner
186
+
187
+ def main():
188
+ parser = argparse.ArgumentParser()
189
+ parser.add_argument('--config', type=str, required=True,
190
+ help='path to config file')
191
+ parser.add_argument('--phase', type=str, default=None, choices=['train', 'test'])
192
+ parser.add_argument('--name', type=str, default=None)
193
+ parser.add_argument('--checkpoint', type=str, default=None)
194
+ parser.add_argument('--test_root', type=str, default=None)
195
+ parser.add_argument("--local_rank", type=int, default=None)
196
+ parser.add_argument('--debug', action='store_true', default=None)
197
+ parser.add_argument('--image_only', action='store_true', default=None)
198
+ parser.add_argument('--model_strict', action='store_false', default=None)
199
+ parser.add_argument('--model_eval', type=str, default=None,
200
+ choices=['alignment', 'vision', 'language'])
201
+ args = parser.parse_args()
202
+ config = Config(args.config)
203
+ if args.name is not None: config.global_name = args.name
204
+ if args.phase is not None: config.global_phase = args.phase
205
+ if args.test_root is not None: config.dataset_test_roots = [args.test_root]
206
+ if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
207
+ if args.debug is not None: config.global_debug = args.debug
208
+ if args.image_only is not None: config.global_image_only = args.image_only
209
+ if args.model_eval is not None: config.model_eval = args.model_eval
210
+ if args.model_strict is not None: config.model_strict = args.model_strict
211
+
212
+ Logger.init(config.global_workdir, config.global_name, config.global_phase)
213
+ Logger.enable_file()
214
+ _set_random_seed(config.global_seed)
215
+ logging.info(config)
216
+
217
+ if args.local_rank is not None:
218
+ logging.info(f'Init distribution training at device {args.local_rank}.')
219
+ torch.cuda.set_device(args.local_rank)
220
+ torch.distributed.init_process_group(backend='nccl', init_method='env://')
221
+
222
+ logging.info('Construct dataset.')
223
+ if config.global_stage == 'pretrain-language': data = _get_language_databaunch(config)
224
+ else: data = _get_databaunch(config)
225
+
226
+ logging.info('Construct model.')
227
+ model = _get_model(config)
228
+
229
+ logging.info('Construct learner.')
230
+ learner = _get_learner(config, data, model, args.local_rank)
231
+
232
+ if config.global_phase == 'train':
233
+ logging.info('Start training.')
234
+ learner.fit(epochs=config.training_epochs,
235
+ lr=config.optimizer_lr)
236
+ else:
237
+ logging.info('Start validation.')
238
+ last_metrics = learner.validate()
239
+ log_str = f'eval loss = {last_metrics[0]:6.3f}, ' \
240
+ f'ccr = {last_metrics[1]:6.3f}, cwr = {last_metrics[2]:6.3f}, ' \
241
+ f'ted = {last_metrics[3]:6.3f}, ned = {last_metrics[4]:6.0f}, ' \
242
+ f'ted/w = {last_metrics[5]:6.3f}, '
243
+ logging.info(log_str)
244
+
245
+ if __name__ == '__main__':
246
+ main()
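Note: the schedule assembled by `_get_training_phases` is piecewise constant; phase `i` runs for `periods[i]` epochs at `lr * gamma ** i`. A stand-alone sketch with made-up values (the real numbers live in the YAML configs):

```python
base_lr = 1e-4        # stands in for config.optimizer_lr
periods = [6, 4]      # stands in for config.optimizer_scheduler_periods
gamma = 0.1           # stands in for config.optimizer_scheduler_gamma

for i, epochs in enumerate(periods):
    print(f'phase {i}: {epochs} epochs at lr={base_lr * gamma ** i:g}')
# phase 0: 6 epochs at lr=0.0001
# phase 1: 4 epochs at lr=1e-05
```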
modules/__init__.py ADDED
File without changes
modules/attention.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .transformer import PositionalEncoding
4
+
5
+ class Attention(nn.Module):
6
+ def __init__(self, in_channels=512, max_length=25, n_feature=256):
7
+ super().__init__()
8
+ self.max_length = max_length
9
+
10
+ self.f0_embedding = nn.Embedding(max_length, in_channels)
11
+ self.w0 = nn.Linear(max_length, n_feature)
12
+ self.wv = nn.Linear(in_channels, in_channels)
13
+ self.we = nn.Linear(in_channels, max_length)
14
+
15
+ self.active = nn.Tanh()
16
+ self.softmax = nn.Softmax(dim=2)
17
+
18
+ def forward(self, enc_output):
19
+ enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2)
20
+ reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device)
21
+ reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S)
22
+ reading_order_embed = self.f0_embedding(reading_order) # b,25,512
23
+
24
+ t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256
25
+ t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512
26
+
27
+ attn = self.we(t) # b,256,25
28
+ attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256
29
+ g_output = torch.bmm(attn, enc_output) # b,25,512
30
+ return g_output, attn.view(*attn.shape[:2], 8, 32)
31
+
32
+
33
+ def encoder_layer(in_c, out_c, k=3, s=2, p=1):
34
+ return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p),
35
+ nn.BatchNorm2d(out_c),
36
+ nn.ReLU(True))
37
+
38
+ def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None):
39
+ align_corners = None if mode=='nearest' else True
40
+ return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor,
41
+ mode=mode, align_corners=align_corners),
42
+ nn.Conv2d(in_c, out_c, k, s, p),
43
+ nn.BatchNorm2d(out_c),
44
+ nn.ReLU(True))
45
+
46
+
47
+ class PositionAttention(nn.Module):
48
+ def __init__(self, max_length, in_channels=512, num_channels=64,
49
+ h=8, w=32, mode='nearest', **kwargs):
50
+ super().__init__()
51
+ self.max_length = max_length
52
+ self.k_encoder = nn.Sequential(
53
+ encoder_layer(in_channels, num_channels, s=(1, 2)),
54
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
55
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
56
+ encoder_layer(num_channels, num_channels, s=(2, 2))
57
+ )
58
+ self.k_decoder = nn.Sequential(
59
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
60
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
61
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
62
+ decoder_layer(num_channels, in_channels, size=(h, w), mode=mode)
63
+ )
64
+
65
+ self.pos_encoder = PositionalEncoding(in_channels, dropout=0, max_len=max_length)
66
+ self.project = nn.Linear(in_channels, in_channels)
67
+
68
+ def forward(self, x):
69
+ N, E, H, W = x.size()
70
+ k, v = x, x # (N, E, H, W)
71
+
72
+ # calculate key vector
73
+ features = []
74
+ for i in range(0, len(self.k_encoder)):
75
+ k = self.k_encoder[i](k)
76
+ features.append(k)
77
+ for i in range(0, len(self.k_decoder) - 1):
78
+ k = self.k_decoder[i](k)
79
+ k = k + features[len(self.k_decoder) - 2 - i]
80
+ k = self.k_decoder[-1](k)
81
+
82
+ # calculate query vector
83
+ # TODO q=f(q,k)
84
+ zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E)
85
+ q = self.pos_encoder(zeros) # (T, N, E)
86
+ q = q.permute(1, 0, 2) # (N, T, E)
87
+ q = self.project(q) # (N, T, E)
88
+
89
+ # calculate attention
90
+ attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W))
91
+ attn_scores = attn_scores / (E ** 0.5)
92
+ attn_scores = torch.softmax(attn_scores, dim=-1)
93
+
94
+ v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E)
95
+ attn_vecs = torch.bmm(attn_scores, v) # (N, T, E)
96
+
97
+ return attn_vecs, attn_scores.view(N, -1, H, W)
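Note: the shape flow of `PositionAttention` can be reproduced in a few lines of plain PyTorch (no repository imports; sizes assume the usual 8x32 feature map and a max length of 25 characters plus the end token):

```python
import torch

N, E, H, W, T = 2, 512, 8, 32, 26                    # batch, channels, feature map, max_length + 1
k = torch.randn(N, E, H, W)                          # keys from the small encoder/decoder over the features
v = k.permute(0, 2, 3, 1).reshape(N, H * W, E)       # values: the features themselves, (N, H*W, E)
q = torch.randn(N, T, E)                             # queries: projected positional encodings

attn = torch.softmax(q @ k.flatten(2, 3) / E ** 0.5, dim=-1)   # (N, T, H*W)
glimpses = attn @ v                                  # (N, T, E): one glimpse per character slot
print(glimpses.shape)                                # torch.Size([2, 26, 512])
```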
modules/backbone.py ADDED
@@ -0,0 +1,36 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from modules.model import _default_tfmer_cfg
6
+ from modules.resnet import resnet45
7
+ from modules.transformer import (PositionalEncoding,
8
+ TransformerEncoder,
9
+ TransformerEncoderLayer)
10
+
11
+
12
+ class ResTranformer(nn.Module):
13
+ def __init__(self, config):
14
+ super().__init__()
15
+ self.resnet = resnet45()
16
+
17
+ self.d_model = ifnone(config.model_vision_d_model, _default_tfmer_cfg['d_model'])
18
+ nhead = ifnone(config.model_vision_nhead, _default_tfmer_cfg['nhead'])
19
+ d_inner = ifnone(config.model_vision_d_inner, _default_tfmer_cfg['d_inner'])
20
+ dropout = ifnone(config.model_vision_dropout, _default_tfmer_cfg['dropout'])
21
+ activation = ifnone(config.model_vision_activation, _default_tfmer_cfg['activation'])
22
+ num_layers = ifnone(config.model_vision_backbone_ln, 2)
23
+
24
+ self.pos_encoder = PositionalEncoding(self.d_model, max_len=8*32)
25
+ encoder_layer = TransformerEncoderLayer(d_model=self.d_model, nhead=nhead,
26
+ dim_feedforward=d_inner, dropout=dropout, activation=activation)
27
+ self.transformer = TransformerEncoder(encoder_layer, num_layers)
28
+
29
+ def forward(self, images):
30
+ feature = self.resnet(images)
31
+ n, c, h, w = feature.shape
32
+ feature = feature.view(n, c, -1).permute(2, 0, 1)
33
+ feature = self.pos_encoder(feature)
34
+ feature = self.transformer(feature)
35
+ feature = feature.permute(1, 2, 0).view(n, c, h, w)
36
+ return feature
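Note: the only subtle step here is the layout change: the (N, C, H, W) feature map becomes an (H*W, N, C) sequence for the transformer encoder and is restored afterwards. A tiny stand-alone check of that round trip:

```python
import torch

n, c, h, w = 2, 512, 8, 32
feature = torch.randn(n, c, h, w)

seq = feature.view(n, c, -1).permute(2, 0, 1)        # (H*W, N, C): sequence of 256 positions
restored = seq.permute(1, 2, 0).view(n, c, h, w)     # back to (N, C, H, W)
assert torch.equal(feature, restored)
```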
modules/model.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from utils import CharsetMapper
5
+
6
+
7
+ _default_tfmer_cfg = dict(d_model=512, nhead=8, d_inner=2048, # 1024
8
+ dropout=0.1, activation='relu')
9
+
10
+ class Model(nn.Module):
11
+
12
+ def __init__(self, config):
13
+ super().__init__()
14
+ self.max_length = config.dataset_max_length + 1
15
+ self.charset = CharsetMapper(config.dataset_charset_path, max_length=self.max_length)
16
+
17
+ def load(self, source, device=None, strict=True):
18
+ state = torch.load(source, map_location=device)
19
+ self.load_state_dict(state['model'], strict=strict)
20
+
21
+ def _get_length(self, logit, dim=-1):
22
+ """ Greed decoder to obtain length from logit"""
23
+ out = (logit.argmax(dim=-1) == self.charset.null_label)
24
+ abn = out.any(dim)
25
+ out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
26
+ out = out + 1 # additional end token
27
+ out = torch.where(abn, out, out.new_tensor(logit.shape[1]))
28
+ return out
29
+
30
+ @staticmethod
31
+ def _get_padding_mask(length, max_length):
32
+ length = length.unsqueeze(-1)
33
+ grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
34
+ return grid >= length
35
+
36
+ @staticmethod
37
+ def _get_square_subsequent_mask(sz, device, diagonal=0, fw=True):
38
+ r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
39
+ Unmasked positions are filled with float(0.0).
40
+ """
41
+ mask = (torch.triu(torch.ones(sz, sz, device=device), diagonal=diagonal) == 1)
42
+ if fw: mask = mask.transpose(0, 1)
43
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
44
+ return mask
45
+
46
+ @staticmethod
47
+ def _get_location_mask(sz, device=None):
48
+ mask = torch.eye(sz, device=device)
49
+ mask = mask.float().masked_fill(mask == 1, float('-inf'))
50
+ return mask
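Note: `_get_length` decodes the length greedily: the predicted length is the index of the first null/end token plus one, or the full sequence length if no end token is predicted. The same logic in a stand-alone sketch (the null label is assumed to be index 0 here purely for illustration):

```python
import torch

null_label = 0                                   # illustrative end-token index
pred = torch.tensor([[5, 7, 0, 3, 0],            # first 0 at position 2 -> length 3
                     [5, 7, 3, 2, 4]])           # no 0 at all          -> length 5

is_null = pred == null_label
first_null = ((is_null.cumsum(-1) == 1) & is_null).long().argmax(-1) + 1
lengths = torch.where(is_null.any(-1), first_null, torch.tensor(pred.shape[1]))
print(lengths)                                   # tensor([3, 5])
```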
modules/model_abinet.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from .model_vision import BaseVision
6
+ from .model_language import BCNLanguage
7
+ from .model_alignment import BaseAlignment
8
+
9
+
10
+ class ABINetModel(nn.Module):
11
+ def __init__(self, config):
12
+ super().__init__()
13
+ self.use_alignment = ifnone(config.model_use_alignment, True)
14
+ self.max_length = config.dataset_max_length + 1 # additional stop token
15
+ self.vision = BaseVision(config)
16
+ self.language = BCNLanguage(config)
17
+ if self.use_alignment: self.alignment = BaseAlignment(config)
18
+
19
+ def forward(self, images, *args):
20
+ v_res = self.vision(images)
21
+ v_tokens = torch.softmax(v_res['logits'], dim=-1)
22
+ v_lengths = v_res['pt_lengths'].clamp_(2, self.max_length) # TODO: move to language model
23
+
24
+ l_res = self.language(v_tokens, v_lengths)
25
+ if not self.use_alignment:
26
+ return l_res, v_res
27
+ l_feature, v_feature = l_res['feature'], v_res['feature']
28
+
29
+ a_res = self.alignment(l_feature, v_feature)
30
+ return a_res, l_res, v_res
modules/model_abinet_iter.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from .model_vision import BaseVision
6
+ from .model_language import BCNLanguage
7
+ from .model_alignment import BaseAlignment
8
+
9
+
10
+ class ABINetIterModel(nn.Module):
11
+ def __init__(self, config):
12
+ super().__init__()
13
+ self.iter_size = ifnone(config.model_iter_size, 1)
14
+ self.max_length = config.dataset_max_length + 1 # additional stop token
15
+ self.vision = BaseVision(config)
16
+ self.language = BCNLanguage(config)
17
+ self.alignment = BaseAlignment(config)
18
+
19
+ def forward(self, images, *args):
20
+ v_res = self.vision(images)
21
+ a_res = v_res
22
+ all_l_res, all_a_res = [], []
23
+ for _ in range(self.iter_size):
24
+ tokens = torch.softmax(a_res['logits'], dim=-1)
25
+ lengths = a_res['pt_lengths']
26
+ lengths.clamp_(2, self.max_length) # TODO: move to language model
27
+ l_res = self.language(tokens, lengths)
28
+ all_l_res.append(l_res)
29
+ a_res = self.alignment(l_res['feature'], v_res['feature'])
30
+ all_a_res.append(a_res)
31
+ if self.training:
32
+ return all_a_res, all_l_res, v_res
33
+ else:
34
+ return a_res, all_l_res[-1], v_res
modules/model_alignment.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from modules.model import Model, _default_tfmer_cfg
6
+
7
+
8
+ class BaseAlignment(Model):
9
+ def __init__(self, config):
10
+ super().__init__(config)
11
+ d_model = ifnone(config.model_alignment_d_model, _default_tfmer_cfg['d_model'])
12
+
13
+ self.loss_weight = ifnone(config.model_alignment_loss_weight, 1.0)
14
+ self.max_length = config.dataset_max_length + 1 # additional stop token
15
+ self.w_att = nn.Linear(2 * d_model, d_model)
16
+ self.cls = nn.Linear(d_model, self.charset.num_classes)
17
+
18
+ def forward(self, l_feature, v_feature):
19
+ """
20
+ Args:
21
+ l_feature: (N, T, E) where T is length, N is batch size and E is dim of model
22
+ v_feature: (N, T, E) shape the same as l_feature
23
+ l_lengths: (N,)
24
+ v_lengths: (N,)
25
+ """
26
+ f = torch.cat((l_feature, v_feature), dim=2)
27
+ f_att = torch.sigmoid(self.w_att(f))
28
+ output = f_att * v_feature + (1 - f_att) * l_feature
29
+
30
+ logits = self.cls(output) # (N, T, C)
31
+ pt_lengths = self._get_length(logits)
32
+
33
+ return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight':self.loss_weight,
34
+ 'name': 'alignment'}
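Note: the fusion is a learned gate: `f_att` decides, per position and channel, how much to trust the vision feature versus the language feature. The same computation in isolation (plain PyTorch, illustrative sizes):

```python
import torch
import torch.nn as nn

N, T, E = 2, 26, 512
w_att = nn.Linear(2 * E, E)

l_feat, v_feat = torch.randn(N, T, E), torch.randn(N, T, E)
gate = torch.sigmoid(w_att(torch.cat((l_feat, v_feat), dim=2)))   # (N, T, E), values in (0, 1)
fused = gate * v_feat + (1 - gate) * l_feat                        # gate -> 1 favours vision
print(fused.shape)                                                 # torch.Size([2, 26, 512])
```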
modules/model_language.py ADDED
@@ -0,0 +1,67 @@
1
+ import logging
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from modules.model import _default_tfmer_cfg
6
+ from modules.model import Model
7
+ from modules.transformer import (PositionalEncoding,
8
+ TransformerDecoder,
9
+ TransformerDecoderLayer)
10
+
11
+
12
+ class BCNLanguage(Model):
13
+ def __init__(self, config):
14
+ super().__init__(config)
15
+ d_model = ifnone(config.model_language_d_model, _default_tfmer_cfg['d_model'])
16
+ nhead = ifnone(config.model_language_nhead, _default_tfmer_cfg['nhead'])
17
+ d_inner = ifnone(config.model_language_d_inner, _default_tfmer_cfg['d_inner'])
18
+ dropout = ifnone(config.model_language_dropout, _default_tfmer_cfg['dropout'])
19
+ activation = ifnone(config.model_language_activation, _default_tfmer_cfg['activation'])
20
+ num_layers = ifnone(config.model_language_num_layers, 4)
21
+ self.d_model = d_model
22
+ self.detach = ifnone(config.model_language_detach, True)
23
+ self.use_self_attn = ifnone(config.model_language_use_self_attn, False)
24
+ self.loss_weight = ifnone(config.model_language_loss_weight, 1.0)
25
+ self.max_length = config.dataset_max_length + 1 # additional stop token
26
+ self.debug = ifnone(config.global_debug, False)
27
+
28
+ self.proj = nn.Linear(self.charset.num_classes, d_model, False)
29
+ self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length)
30
+ self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length)
31
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout,
32
+ activation, self_attn=self.use_self_attn, debug=self.debug)
33
+ self.model = TransformerDecoder(decoder_layer, num_layers)
34
+
35
+ self.cls = nn.Linear(d_model, self.charset.num_classes)
36
+
37
+ if config.model_language_checkpoint is not None:
38
+ logging.info(f'Read language model from {config.model_language_checkpoint}.')
39
+ self.load(config.model_language_checkpoint)
40
+
41
+ def forward(self, tokens, lengths):
42
+ """
43
+ Args:
44
+ tokens: (N, T, C) where T is length, N is batch size and C is classes number
45
+ lengths: (N,)
46
+ """
47
+ if self.detach: tokens = tokens.detach()
48
+ embed = self.proj(tokens) # (N, T, E)
49
+ embed = embed.permute(1, 0, 2) # (T, N, E)
50
+ embed = self.token_encoder(embed) # (T, N, E)
51
+ padding_mask = self._get_padding_mask(lengths, self.max_length)
52
+
53
+ zeros = embed.new_zeros(*embed.shape)
54
+ query = self.pos_encoder(zeros)
55
+ location_mask = self._get_location_mask(self.max_length, tokens.device)
56
+ output = self.model(query, embed,
57
+ tgt_key_padding_mask=padding_mask,
58
+ memory_mask=location_mask,
59
+ memory_key_padding_mask=padding_mask) # (T, N, E)
60
+ output = output.permute(1, 0, 2) # (N, T, E)
61
+
62
+ logits = self.cls(output) # (N, T, C)
63
+ pt_lengths = self._get_length(logits)
64
+
65
+ res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths,
66
+ 'loss_weight':self.loss_weight, 'name': 'language'}
67
+ return res
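Note: what makes BCN behave like a cloze test is the `memory_mask` from `_get_location_mask`: a `-inf` diagonal stops each position from attending to its own (possibly wrong) input token, so every character has to be inferred from its context. For a sequence length of 5 the mask looks like this:

```python
import torch

sz = 5
mask = torch.eye(sz)
mask = mask.masked_fill(mask == 1, float('-inf'))   # -inf on the diagonal, 0 elsewhere
print(mask)
# tensor([[-inf, 0., 0., 0., 0.],
#         [0., -inf, 0., 0., 0.],
#         [0., 0., -inf, 0., 0.],
#         [0., 0., 0., -inf, 0.],
#         [0., 0., 0., 0., -inf]])
```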
modules/model_vision.py ADDED
@@ -0,0 +1,47 @@
1
+ import logging
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from modules.attention import *
6
+ from modules.backbone import ResTranformer
7
+ from modules.model import Model
8
+ from modules.resnet import resnet45
9
+
10
+
11
+ class BaseVision(Model):
12
+ def __init__(self, config):
13
+ super().__init__(config)
14
+ self.loss_weight = ifnone(config.model_vision_loss_weight, 1.0)
15
+ self.out_channels = ifnone(config.model_vision_d_model, 512)
16
+
17
+ if config.model_vision_backbone == 'transformer':
18
+ self.backbone = ResTranformer(config)
19
+ else: self.backbone = resnet45()
20
+
21
+ if config.model_vision_attention == 'position':
22
+ mode = ifnone(config.model_vision_attention_mode, 'nearest')
23
+ self.attention = PositionAttention(
24
+ max_length=config.dataset_max_length + 1, # additional stop token
25
+ mode=mode,
26
+ )
27
+ elif config.model_vision_attention == 'attention':
28
+ self.attention = Attention(
29
+ max_length=config.dataset_max_length + 1, # additional stop token
30
+ n_feature=8*32,
31
+ )
32
+ else:
33
+ raise Exception(f'{config.model_vision_attention} is not valid.')
34
+ self.cls = nn.Linear(self.out_channels, self.charset.num_classes)
35
+
36
+ if config.model_vision_checkpoint is not None:
37
+ logging.info(f'Read vision model from {config.model_vision_checkpoint}.')
38
+ self.load(config.model_vision_checkpoint)
39
+
40
+ def forward(self, images, *args):
41
+ features = self.backbone(images) # (N, E, H, W)
42
+ attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W)
43
+ logits = self.cls(attn_vecs) # (N, T, C)
44
+ pt_lengths = self._get_length(logits)
45
+
46
+ return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
47
+ 'attn_scores': attn_scores, 'loss_weight':self.loss_weight, 'name': 'vision'}
modules/resnet.py ADDED
@@ -0,0 +1,104 @@
1
+ import math
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.model_zoo as model_zoo
6
+
7
+
8
+ def conv1x1(in_planes, out_planes, stride=1):
9
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
10
+
11
+
12
+ def conv3x3(in_planes, out_planes, stride=1):
13
+ "3x3 convolution with padding"
14
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
15
+ padding=1, bias=False)
16
+
17
+
18
+ class BasicBlock(nn.Module):
19
+ expansion = 1
20
+
21
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
22
+ super(BasicBlock, self).__init__()
23
+ self.conv1 = conv1x1(inplanes, planes)
24
+ self.bn1 = nn.BatchNorm2d(planes)
25
+ self.relu = nn.ReLU(inplace=True)
26
+ self.conv2 = conv3x3(planes, planes, stride)
27
+ self.bn2 = nn.BatchNorm2d(planes)
28
+ self.downsample = downsample
29
+ self.stride = stride
30
+
31
+ def forward(self, x):
32
+ residual = x
33
+
34
+ out = self.conv1(x)
35
+ out = self.bn1(out)
36
+ out = self.relu(out)
37
+
38
+ out = self.conv2(out)
39
+ out = self.bn2(out)
40
+
41
+ if self.downsample is not None:
42
+ residual = self.downsample(x)
43
+
44
+ out += residual
45
+ out = self.relu(out)
46
+
47
+ return out
48
+
49
+
50
+ class ResNet(nn.Module):
51
+
52
+ def __init__(self, block, layers):
53
+ self.inplanes = 32
54
+ super(ResNet, self).__init__()
55
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1,
56
+ bias=False)
57
+ self.bn1 = nn.BatchNorm2d(32)
58
+ self.relu = nn.ReLU(inplace=True)
59
+
60
+ self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
61
+ self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
62
+ self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
63
+ self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
64
+ self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
65
+
66
+ for m in self.modules():
67
+ if isinstance(m, nn.Conv2d):
68
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
69
+ m.weight.data.normal_(0, math.sqrt(2. / n))
70
+ elif isinstance(m, nn.BatchNorm2d):
71
+ m.weight.data.fill_(1)
72
+ m.bias.data.zero_()
73
+
74
+ def _make_layer(self, block, planes, blocks, stride=1):
75
+ downsample = None
76
+ if stride != 1 or self.inplanes != planes * block.expansion:
77
+ downsample = nn.Sequential(
78
+ nn.Conv2d(self.inplanes, planes * block.expansion,
79
+ kernel_size=1, stride=stride, bias=False),
80
+ nn.BatchNorm2d(planes * block.expansion),
81
+ )
82
+
83
+ layers = []
84
+ layers.append(block(self.inplanes, planes, stride, downsample))
85
+ self.inplanes = planes * block.expansion
86
+ for i in range(1, blocks):
87
+ layers.append(block(self.inplanes, planes))
88
+
89
+ return nn.Sequential(*layers)
90
+
91
+ def forward(self, x):
92
+ x = self.conv1(x)
93
+ x = self.bn1(x)
94
+ x = self.relu(x)
95
+ x = self.layer1(x)
96
+ x = self.layer2(x)
97
+ x = self.layer3(x)
98
+ x = self.layer4(x)
99
+ x = self.layer5(x)
100
+ return x
101
+
102
+
103
+ def resnet45():
104
+ return ResNet(BasicBlock, [3, 4, 6, 6, 3])
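Note: with the strides above (2 in `layer1` and `layer3`, 1 everywhere else), `resnet45` downsamples a 32x128 text crop by a factor of 4 per side, giving the 8x32x512 feature map that the attention modules and the `max_len=8*32` positional encoding assume. A quick check, assuming the repository root is on `PYTHONPATH`:

```python
import torch
from modules.resnet import resnet45

x = torch.randn(1, 3, 32, 128)        # (N, C, H, W) input text crop
print(resnet45()(x).shape)             # torch.Size([1, 512, 8, 32])
```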
modules/transformer.py ADDED
@@ -0,0 +1,901 @@
1
+ # pytorch 1.5.0
2
+ import copy
3
+ import math
4
+ import warnings
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch import Tensor
10
+ from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList, Parameter
11
+ from torch.nn import functional as F
12
+ from torch.nn.init import constant_, xavier_uniform_
13
+
14
+
15
+ def multi_head_attention_forward(query, # type: Tensor
16
+ key, # type: Tensor
17
+ value, # type: Tensor
18
+ embed_dim_to_check, # type: int
19
+ num_heads, # type: int
20
+ in_proj_weight, # type: Tensor
21
+ in_proj_bias, # type: Tensor
22
+ bias_k, # type: Optional[Tensor]
23
+ bias_v, # type: Optional[Tensor]
24
+ add_zero_attn, # type: bool
25
+ dropout_p, # type: float
26
+ out_proj_weight, # type: Tensor
27
+ out_proj_bias, # type: Tensor
28
+ training=True, # type: bool
29
+ key_padding_mask=None, # type: Optional[Tensor]
30
+ need_weights=True, # type: bool
31
+ attn_mask=None, # type: Optional[Tensor]
32
+ use_separate_proj_weight=False, # type: bool
33
+ q_proj_weight=None, # type: Optional[Tensor]
34
+ k_proj_weight=None, # type: Optional[Tensor]
35
+ v_proj_weight=None, # type: Optional[Tensor]
36
+ static_k=None, # type: Optional[Tensor]
37
+ static_v=None # type: Optional[Tensor]
38
+ ):
39
+ # type: (...) -> Tuple[Tensor, Optional[Tensor]]
40
+ r"""
41
+ Args:
42
+ query, key, value: map a query and a set of key-value pairs to an output.
43
+ See "Attention Is All You Need" for more details.
44
+ embed_dim_to_check: total dimension of the model.
45
+ num_heads: parallel attention heads.
46
+ in_proj_weight, in_proj_bias: input projection weight and bias.
47
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
48
+ add_zero_attn: add a new batch of zeros to the key and
49
+ value sequences at dim=1.
50
+ dropout_p: probability of an element to be zeroed.
51
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
52
+ training: apply dropout if is ``True``.
53
+ key_padding_mask: if provided, specified padding elements in the key will
54
+ be ignored by the attention. This is an binary mask. When the value is True,
55
+ the corresponding value on the attention layer will be filled with -inf.
56
+ need_weights: output attn_output_weights.
57
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
58
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
59
+ use_separate_proj_weight: the function accept the proj. weights for query, key,
60
+ and value in different forms. If false, in_proj_weight will be used, which is
61
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
62
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
63
+ static_k, static_v: static key and value used for attention operators.
64
+ Shape:
65
+ Inputs:
66
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
67
+ the embedding dimension.
68
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
69
+ the embedding dimension.
70
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
71
+ the embedding dimension.
72
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
73
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
74
+ will be unchanged. If a BoolTensor is provided, the positions with the
75
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
76
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
77
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
78
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
79
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
80
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
81
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
82
+ is provided, it will be added to the attention weight.
83
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
84
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
85
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
86
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
87
+ Outputs:
88
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
89
+ E is the embedding dimension.
90
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
91
+ L is the target sequence length, S is the source sequence length.
92
+ """
93
+ # if not torch.jit.is_scripting():
94
+ # tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
95
+ # out_proj_weight, out_proj_bias)
96
+ # if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
97
+ # return handle_torch_function(
98
+ # multi_head_attention_forward, tens_ops, query, key, value,
99
+ # embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
100
+ # bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
101
+ # out_proj_bias, training=training, key_padding_mask=key_padding_mask,
102
+ # need_weights=need_weights, attn_mask=attn_mask,
103
+ # use_separate_proj_weight=use_separate_proj_weight,
104
+ # q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
105
+ # v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
106
+ tgt_len, bsz, embed_dim = query.size()
107
+ assert embed_dim == embed_dim_to_check
108
+ assert key.size() == value.size()
109
+
110
+ head_dim = embed_dim // num_heads
111
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
112
+ scaling = float(head_dim) ** -0.5
113
+
114
+ if not use_separate_proj_weight:
115
+ if torch.equal(query, key) and torch.equal(key, value):
116
+ # self-attention
117
+ q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
118
+
119
+ elif torch.equal(key, value):
120
+ # encoder-decoder attention
121
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
122
+ _b = in_proj_bias
123
+ _start = 0
124
+ _end = embed_dim
125
+ _w = in_proj_weight[_start:_end, :]
126
+ if _b is not None:
127
+ _b = _b[_start:_end]
128
+ q = F.linear(query, _w, _b)
129
+
130
+ if key is None:
131
+ assert value is None
132
+ k = None
133
+ v = None
134
+ else:
135
+
136
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
137
+ _b = in_proj_bias
138
+ _start = embed_dim
139
+ _end = None
140
+ _w = in_proj_weight[_start:, :]
141
+ if _b is not None:
142
+ _b = _b[_start:]
143
+ k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
144
+
145
+ else:
146
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
147
+ _b = in_proj_bias
148
+ _start = 0
149
+ _end = embed_dim
150
+ _w = in_proj_weight[_start:_end, :]
151
+ if _b is not None:
152
+ _b = _b[_start:_end]
153
+ q = F.linear(query, _w, _b)
154
+
155
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
156
+ _b = in_proj_bias
157
+ _start = embed_dim
158
+ _end = embed_dim * 2
159
+ _w = in_proj_weight[_start:_end, :]
160
+ if _b is not None:
161
+ _b = _b[_start:_end]
162
+ k = F.linear(key, _w, _b)
163
+
164
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
165
+ _b = in_proj_bias
166
+ _start = embed_dim * 2
167
+ _end = None
168
+ _w = in_proj_weight[_start:, :]
169
+ if _b is not None:
170
+ _b = _b[_start:]
171
+ v = F.linear(value, _w, _b)
172
+ else:
173
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
174
+ len1, len2 = q_proj_weight_non_opt.size()
175
+ assert len1 == embed_dim and len2 == query.size(-1)
176
+
177
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
178
+ len1, len2 = k_proj_weight_non_opt.size()
179
+ assert len1 == embed_dim and len2 == key.size(-1)
180
+
181
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
182
+ len1, len2 = v_proj_weight_non_opt.size()
183
+ assert len1 == embed_dim and len2 == value.size(-1)
184
+
185
+ if in_proj_bias is not None:
186
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
187
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
188
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
189
+ else:
190
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
191
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
192
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
193
+ q = q * scaling
194
+
195
+ if attn_mask is not None:
196
+ assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
197
+ attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
198
+ 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
199
+ if attn_mask.dtype == torch.uint8:
200
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
201
+ attn_mask = attn_mask.to(torch.bool)
202
+
203
+ if attn_mask.dim() == 2:
204
+ attn_mask = attn_mask.unsqueeze(0)
205
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
206
+ raise RuntimeError('The size of the 2D attn_mask is not correct.')
207
+ elif attn_mask.dim() == 3:
208
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
209
+ raise RuntimeError('The size of the 3D attn_mask is not correct.')
210
+ else:
211
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
212
+ # attn_mask's dim is 3 now.
213
+
214
+ # # convert ByteTensor key_padding_mask to bool
215
+ # if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
216
+ # warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
217
+ # key_padding_mask = key_padding_mask.to(torch.bool)
218
+
219
+ if bias_k is not None and bias_v is not None:
220
+ if static_k is None and static_v is None:
221
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
222
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
223
+ if attn_mask is not None:
224
+ attn_mask = pad(attn_mask, (0, 1))
225
+ if key_padding_mask is not None:
226
+ key_padding_mask = pad(key_padding_mask, (0, 1))
227
+ else:
228
+ assert static_k is None, "bias cannot be added to static key."
229
+ assert static_v is None, "bias cannot be added to static value."
230
+ else:
231
+ assert bias_k is None
232
+ assert bias_v is None
233
+
234
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
235
+ if k is not None:
236
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
237
+ if v is not None:
238
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
239
+
240
+ if static_k is not None:
241
+ assert static_k.size(0) == bsz * num_heads
242
+ assert static_k.size(2) == head_dim
243
+ k = static_k
244
+
245
+ if static_v is not None:
246
+ assert static_v.size(0) == bsz * num_heads
247
+ assert static_v.size(2) == head_dim
248
+ v = static_v
249
+
250
+ src_len = k.size(1)
251
+
252
+ if key_padding_mask is not None:
253
+ assert key_padding_mask.size(0) == bsz
254
+ assert key_padding_mask.size(1) == src_len
255
+
256
+ if add_zero_attn:
257
+ src_len += 1
258
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
259
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
260
+ if attn_mask is not None:
261
+ attn_mask = pad(attn_mask, (0, 1))
262
+ if key_padding_mask is not None:
263
+ key_padding_mask = pad(key_padding_mask, (0, 1))
264
+
265
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
266
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
267
+
268
+ if attn_mask is not None:
269
+ if attn_mask.dtype == torch.bool:
270
+ attn_output_weights.masked_fill_(attn_mask, float('-inf'))
271
+ else:
272
+ attn_output_weights += attn_mask
273
+
274
+
275
+ if key_padding_mask is not None:
276
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
277
+ attn_output_weights = attn_output_weights.masked_fill(
278
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
279
+ float('-inf'),
280
+ )
281
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
282
+
283
+ attn_output_weights = F.softmax(
284
+ attn_output_weights, dim=-1)
285
+ attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
286
+
287
+ attn_output = torch.bmm(attn_output_weights, v)
288
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
289
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
290
+ attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
291
+
292
+ if need_weights:
293
+ # average attention weights over heads
294
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
295
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
296
+ else:
297
+ return attn_output, None
298
+
299
+ class MultiheadAttention(Module):
300
+ r"""Allows the model to jointly attend to information
301
+ from different representation subspaces.
302
+ See reference: Attention Is All You Need
303
+ .. math::
304
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
305
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
306
+ Args:
307
+ embed_dim: total dimension of the model.
308
+ num_heads: parallel attention heads.
309
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
310
+ bias: add bias as module parameter. Default: True.
311
+ add_bias_kv: add bias to the key and value sequences at dim=0.
312
+ add_zero_attn: add a new batch of zeros to the key and
313
+ value sequences at dim=1.
314
+ kdim: total number of features in key. Default: None.
315
+ vdim: total number of features in value. Default: None.
316
+ Note: if kdim and vdim are None, they will be set to embed_dim such that
317
+ query, key, and value have the same number of features.
318
+ Examples::
319
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
320
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
321
+ """
322
+ # __annotations__ = {
323
+ # 'bias_k': torch._jit_internal.Optional[torch.Tensor],
324
+ # 'bias_v': torch._jit_internal.Optional[torch.Tensor],
325
+ # }
326
+ __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']
327
+
328
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
329
+ super(MultiheadAttention, self).__init__()
330
+ self.embed_dim = embed_dim
331
+ self.kdim = kdim if kdim is not None else embed_dim
332
+ self.vdim = vdim if vdim is not None else embed_dim
333
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
334
+
335
+ self.num_heads = num_heads
336
+ self.dropout = dropout
337
+ self.head_dim = embed_dim // num_heads
338
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
339
+
340
+ if self._qkv_same_embed_dim is False:
341
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
342
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
343
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
344
+ self.register_parameter('in_proj_weight', None)
345
+ else:
346
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
347
+ self.register_parameter('q_proj_weight', None)
348
+ self.register_parameter('k_proj_weight', None)
349
+ self.register_parameter('v_proj_weight', None)
350
+
351
+ if bias:
352
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
353
+ else:
354
+ self.register_parameter('in_proj_bias', None)
355
+ self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
356
+
357
+ if add_bias_kv:
358
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
359
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
360
+ else:
361
+ self.bias_k = self.bias_v = None
362
+
363
+ self.add_zero_attn = add_zero_attn
364
+
365
+ self._reset_parameters()
366
+
367
+ def _reset_parameters(self):
368
+ if self._qkv_same_embed_dim:
369
+ xavier_uniform_(self.in_proj_weight)
370
+ else:
371
+ xavier_uniform_(self.q_proj_weight)
372
+ xavier_uniform_(self.k_proj_weight)
373
+ xavier_uniform_(self.v_proj_weight)
374
+
375
+ if self.in_proj_bias is not None:
376
+ constant_(self.in_proj_bias, 0.)
377
+ constant_(self.out_proj.bias, 0.)
378
+ if self.bias_k is not None:
379
+ xavier_normal_(self.bias_k)
380
+ if self.bias_v is not None:
381
+ xavier_normal_(self.bias_v)
382
+
383
+ def __setstate__(self, state):
384
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
385
+ if '_qkv_same_embed_dim' not in state:
386
+ state['_qkv_same_embed_dim'] = True
387
+
388
+ super(MultiheadAttention, self).__setstate__(state)
389
+
390
+ def forward(self, query, key, value, key_padding_mask=None,
391
+ need_weights=True, attn_mask=None):
392
+ # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
393
+ r"""
394
+ Args:
395
+ query, key, value: map a query and a set of key-value pairs to an output.
396
+ See "Attention Is All You Need" for more details.
397
+ key_padding_mask: if provided, specified padding elements in the key will
398
+ be ignored by the attention. This is an binary mask. When the value is True,
399
+ the corresponding value on the attention layer will be filled with -inf.
400
+ need_weights: output attn_output_weights.
401
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
402
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
403
+ Shape:
404
+ - Inputs:
405
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
406
+ the embedding dimension.
407
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
408
+ the embedding dimension.
409
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
410
+ the embedding dimension.
411
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
412
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
413
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
414
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
415
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
416
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
417
+ S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
418
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
419
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
420
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
421
+ is provided, it will be added to the attention weight.
422
+ - Outputs:
423
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
424
+ E is the embedding dimension.
425
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
426
+ L is the target sequence length, S is the source sequence length.
427
+ """
428
+ if not self._qkv_same_embed_dim:
429
+ return multi_head_attention_forward(
430
+ query, key, value, self.embed_dim, self.num_heads,
431
+ self.in_proj_weight, self.in_proj_bias,
432
+ self.bias_k, self.bias_v, self.add_zero_attn,
433
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
434
+ training=self.training,
435
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
436
+ attn_mask=attn_mask, use_separate_proj_weight=True,
437
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
438
+ v_proj_weight=self.v_proj_weight)
439
+ else:
440
+ return multi_head_attention_forward(
441
+ query, key, value, self.embed_dim, self.num_heads,
442
+ self.in_proj_weight, self.in_proj_bias,
443
+ self.bias_k, self.bias_v, self.add_zero_attn,
444
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
445
+ training=self.training,
446
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
447
+ attn_mask=attn_mask)
448
+
449
+
450
+ class Transformer(Module):
451
+ r"""A transformer model. User is able to modify the attributes as needed. The architecture
452
+ is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
453
+ Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
454
+ Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
455
+ Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)
456
+ model with corresponding parameters.
457
+
458
+ Args:
459
+ d_model: the number of expected features in the encoder/decoder inputs (default=512).
460
+ nhead: the number of heads in the multiheadattention models (default=8).
461
+ num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
462
+ num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
463
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
464
+ dropout: the dropout value (default=0.1).
465
+ activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu).
466
+ custom_encoder: custom encoder (default=None).
467
+ custom_decoder: custom decoder (default=None).
468
+
469
+ Examples::
470
+ >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
471
+ >>> src = torch.rand((10, 32, 512))
472
+ >>> tgt = torch.rand((20, 32, 512))
473
+ >>> out = transformer_model(src, tgt)
474
+
475
+ Note: A full example of applying the nn.Transformer module to a word language model is available at
476
+ https://github.com/pytorch/examples/tree/master/word_language_model
477
+ """
478
+
479
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
480
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
481
+ activation="relu", custom_encoder=None, custom_decoder=None):
482
+ super(Transformer, self).__init__()
483
+
484
+ if custom_encoder is not None:
485
+ self.encoder = custom_encoder
486
+ else:
487
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
488
+ encoder_norm = LayerNorm(d_model)
489
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
490
+
491
+ if custom_decoder is not None:
492
+ self.decoder = custom_decoder
493
+ else:
494
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
495
+ decoder_norm = LayerNorm(d_model)
496
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
497
+
498
+ self._reset_parameters()
499
+
500
+ self.d_model = d_model
501
+ self.nhead = nhead
502
+
503
+ def forward(self, src, tgt, src_mask=None, tgt_mask=None,
504
+ memory_mask=None, src_key_padding_mask=None,
505
+ tgt_key_padding_mask=None, memory_key_padding_mask=None):
506
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor # noqa
507
+ r"""Take in and process masked source/target sequences.
508
+
509
+ Args:
510
+ src: the sequence to the encoder (required).
511
+ tgt: the sequence to the decoder (required).
512
+ src_mask: the additive mask for the src sequence (optional).
513
+ tgt_mask: the additive mask for the tgt sequence (optional).
514
+ memory_mask: the additive mask for the encoder output (optional).
515
+ src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
516
+ tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
517
+ memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
518
+
519
+ Shape:
520
+ - src: :math:`(S, N, E)`.
521
+ - tgt: :math:`(T, N, E)`.
522
+ - src_mask: :math:`(S, S)`.
523
+ - tgt_mask: :math:`(T, T)`.
524
+ - memory_mask: :math:`(T, S)`.
525
+ - src_key_padding_mask: :math:`(N, S)`.
526
+ - tgt_key_padding_mask: :math:`(N, T)`.
527
+ - memory_key_padding_mask: :math:`(N, S)`.
528
+
529
+ Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
530
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
531
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
532
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
533
+ is provided, it will be added to the attention weight.
534
+ [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
535
+ the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
536
+ positions will be unchanged. If a BoolTensor is provided, the positions with the
537
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
538
+
539
+ - output: :math:`(T, N, E)`.
540
+
541
+ Note: Due to the multi-head attention architecture in the transformer model,
542
+ the output sequence length of a transformer is the same as the input sequence
543
+ (i.e. target) length of the decoder.
544
+
545
+ where S is the source sequence length, T is the target sequence length, N is the
546
+ batch size, and E is the feature number.
547
+
548
+ Examples:
549
+ >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
550
+ """
551
+
552
+ if src.size(1) != tgt.size(1):
553
+ raise RuntimeError("the batch number of src and tgt must be equal")
554
+
555
+ if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
556
+ raise RuntimeError("the feature number of src and tgt must be equal to d_model")
557
+
558
+ memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
559
+ output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
560
+ tgt_key_padding_mask=tgt_key_padding_mask,
561
+ memory_key_padding_mask=memory_key_padding_mask)
562
+ return output
563
+
564
+ def generate_square_subsequent_mask(self, sz):
565
+ r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
566
+ Unmasked positions are filled with float(0.0).
567
+ """
568
+ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
569
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
570
+ return mask
571
+
572
+ def _reset_parameters(self):
573
+ r"""Initiate parameters in the transformer model."""
574
+
575
+ for p in self.parameters():
576
+ if p.dim() > 1:
577
+ xavier_uniform_(p)
578
+
579
+
580
+ class TransformerEncoder(Module):
581
+ r"""TransformerEncoder is a stack of N encoder layers
582
+
583
+ Args:
584
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
585
+ num_layers: the number of sub-encoder-layers in the encoder (required).
586
+ norm: the layer normalization component (optional).
587
+
588
+ Examples::
589
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
590
+ >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
591
+ >>> src = torch.rand(10, 32, 512)
592
+ >>> out = transformer_encoder(src)
593
+ """
594
+ __constants__ = ['norm']
595
+
596
+ def __init__(self, encoder_layer, num_layers, norm=None):
597
+ super(TransformerEncoder, self).__init__()
598
+ self.layers = _get_clones(encoder_layer, num_layers)
599
+ self.num_layers = num_layers
600
+ self.norm = norm
601
+
602
+ def forward(self, src, mask=None, src_key_padding_mask=None):
603
+ # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
604
+ r"""Pass the input through the encoder layers in turn.
605
+
606
+ Args:
607
+ src: the sequence to the encoder (required).
608
+ mask: the mask for the src sequence (optional).
609
+ src_key_padding_mask: the mask for the src keys per batch (optional).
610
+
611
+ Shape:
612
+ see the docs in Transformer class.
613
+ """
614
+ output = src
615
+
616
+ for i, mod in enumerate(self.layers):
617
+ output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
618
+
619
+ if self.norm is not None:
620
+ output = self.norm(output)
621
+
622
+ return output
623
+
624
+
625
+ class TransformerDecoder(Module):
626
+ r"""TransformerDecoder is a stack of N decoder layers
627
+
628
+ Args:
629
+ decoder_layer: an instance of the TransformerDecoderLayer() class (required).
630
+ num_layers: the number of sub-decoder-layers in the decoder (required).
631
+ norm: the layer normalization component (optional).
632
+
633
+ Examples::
634
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
635
+ >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
636
+ >>> memory = torch.rand(10, 32, 512)
637
+ >>> tgt = torch.rand(20, 32, 512)
638
+ >>> out = transformer_decoder(tgt, memory)
639
+ """
640
+ __constants__ = ['norm']
641
+
642
+ def __init__(self, decoder_layer, num_layers, norm=None):
643
+ super(TransformerDecoder, self).__init__()
644
+ self.layers = _get_clones(decoder_layer, num_layers)
645
+ self.num_layers = num_layers
646
+ self.norm = norm
647
+
648
+ def forward(self, tgt, memory, memory2=None, tgt_mask=None,
649
+ memory_mask=None, memory_mask2=None, tgt_key_padding_mask=None,
650
+ memory_key_padding_mask=None, memory_key_padding_mask2=None):
651
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
652
+ r"""Pass the inputs (and mask) through the decoder layer in turn.
653
+
654
+ Args:
655
+ tgt: the sequence to the decoder (required).
656
+ memory: the sequence from the last layer of the encoder (required).
657
+ tgt_mask: the mask for the tgt sequence (optional).
658
+ memory_mask: the mask for the memory sequence (optional).
659
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
660
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
661
+
662
+ Shape:
663
+ see the docs in Transformer class.
664
+ """
665
+ output = tgt
666
+
667
+ for mod in self.layers:
668
+ output = mod(output, memory, memory2=memory2, tgt_mask=tgt_mask,
669
+ memory_mask=memory_mask, memory_mask2=memory_mask2,
670
+ tgt_key_padding_mask=tgt_key_padding_mask,
671
+ memory_key_padding_mask=memory_key_padding_mask,
672
+ memory_key_padding_mask2=memory_key_padding_mask2)
673
+
674
+ if self.norm is not None:
675
+ output = self.norm(output)
676
+
677
+ return output
678
+
679
+ class TransformerEncoderLayer(Module):
680
+ r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
681
+ This standard encoder layer is based on the paper "Attention Is All You Need".
682
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
683
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
684
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
685
+ in a different way during application.
686
+
687
+ Args:
688
+ d_model: the number of expected features in the input (required).
689
+ nhead: the number of heads in the multiheadattention models (required).
690
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
691
+ dropout: the dropout value (default=0.1).
692
+ activation: the activation function of intermediate layer, relu or gelu (default=relu).
693
+
694
+ Examples::
695
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
696
+ >>> src = torch.rand(10, 32, 512)
697
+ >>> out = encoder_layer(src)
698
+ """
699
+
700
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
701
+ activation="relu", debug=False):
702
+ super(TransformerEncoderLayer, self).__init__()
703
+ self.debug = debug
704
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
705
+ # Implementation of Feedforward model
706
+ self.linear1 = Linear(d_model, dim_feedforward)
707
+ self.dropout = Dropout(dropout)
708
+ self.linear2 = Linear(dim_feedforward, d_model)
709
+
710
+ self.norm1 = LayerNorm(d_model)
711
+ self.norm2 = LayerNorm(d_model)
712
+ self.dropout1 = Dropout(dropout)
713
+ self.dropout2 = Dropout(dropout)
714
+
715
+ self.activation = _get_activation_fn(activation)
716
+
717
+ def __setstate__(self, state):
718
+ if 'activation' not in state:
719
+ state['activation'] = F.relu
720
+ super(TransformerEncoderLayer, self).__setstate__(state)
721
+
722
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
723
+ # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
724
+ r"""Pass the input through the encoder layer.
725
+
726
+ Args:
727
+ src: the sequence to the encoder layer (required).
728
+ src_mask: the mask for the src sequence (optional).
729
+ src_key_padding_mask: the mask for the src keys per batch (optional).
730
+
731
+ Shape:
732
+ see the docs in Transformer class.
733
+ """
734
+ src2, attn = self.self_attn(src, src, src, attn_mask=src_mask,
735
+ key_padding_mask=src_key_padding_mask)
736
+ if self.debug: self.attn = attn
737
+ src = src + self.dropout1(src2)
738
+ src = self.norm1(src)
739
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
740
+ src = src + self.dropout2(src2)
741
+ src = self.norm2(src)
742
+
743
+ return src
744
+
745
+
746
+ class TransformerDecoderLayer(Module):
747
+ r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
748
+ This standard decoder layer is based on the paper "Attention Is All You Need".
749
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
750
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
751
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
752
+ in a different way during application.
753
+
754
+ Args:
755
+ d_model: the number of expected features in the input (required).
756
+ nhead: the number of heads in the multiheadattention models (required).
757
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
758
+ dropout: the dropout value (default=0.1).
759
+ activation: the activation function of intermediate layer, relu or gelu (default=relu).
760
+
761
+ Examples::
762
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
763
+ >>> memory = torch.rand(10, 32, 512)
764
+ >>> tgt = torch.rand(20, 32, 512)
765
+ >>> out = decoder_layer(tgt, memory)
766
+ """
767
+
768
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
769
+ activation="relu", self_attn=True, siamese=False, debug=False):
770
+ super(TransformerDecoderLayer, self).__init__()
771
+ self.has_self_attn, self.siamese = self_attn, siamese
772
+ self.debug = debug
773
+ if self.has_self_attn:
774
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
775
+ self.norm1 = LayerNorm(d_model)
776
+ self.dropout1 = Dropout(dropout)
777
+ self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
778
+ # Implementation of Feedforward model
779
+ self.linear1 = Linear(d_model, dim_feedforward)
780
+ self.dropout = Dropout(dropout)
781
+ self.linear2 = Linear(dim_feedforward, d_model)
782
+
783
+ self.norm2 = LayerNorm(d_model)
784
+ self.norm3 = LayerNorm(d_model)
785
+ self.dropout2 = Dropout(dropout)
786
+ self.dropout3 = Dropout(dropout)
787
+ if self.siamese:
788
+ self.multihead_attn2 = MultiheadAttention(d_model, nhead, dropout=dropout)
789
+
790
+ self.activation = _get_activation_fn(activation)
791
+
792
+ def __setstate__(self, state):
793
+ if 'activation' not in state:
794
+ state['activation'] = F.relu
795
+ super(TransformerDecoderLayer, self).__setstate__(state)
796
+
797
+ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
798
+ tgt_key_padding_mask=None, memory_key_padding_mask=None,
799
+ memory2=None, memory_mask2=None, memory_key_padding_mask2=None):
800
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
801
+ r"""Pass the inputs (and mask) through the decoder layer.
802
+
803
+ Args:
804
+ tgt: the sequence to the decoder layer (required).
805
+ memory: the sequence from the last layer of the encoder (required).
806
+ tgt_mask: the mask for the tgt sequence (optional).
807
+ memory_mask: the mask for the memory sequence (optional).
808
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
809
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
810
+
811
+ Shape:
812
+ see the docs in Transformer class.
813
+ """
814
+ if self.has_self_attn:
815
+ tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
816
+ key_padding_mask=tgt_key_padding_mask)
817
+ tgt = tgt + self.dropout1(tgt2)
818
+ tgt = self.norm1(tgt)
819
+ if self.debug: self.attn = attn
820
+ tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
821
+ key_padding_mask=memory_key_padding_mask)
822
+ if self.debug: self.attn2 = attn2
823
+
824
+ if self.siamese:
825
+ tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2,
826
+ key_padding_mask=memory_key_padding_mask2)
827
+ tgt = tgt + self.dropout2(tgt3)
828
+ if self.debug: self.attn3 = attn3
829
+
830
+ tgt = tgt + self.dropout2(tgt2)
831
+ tgt = self.norm2(tgt)
832
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
833
+ tgt = tgt + self.dropout3(tgt2)
834
+ tgt = self.norm3(tgt)
835
+
836
+ return tgt
837
+
838
+
839
+ def _get_clones(module, N):
840
+ return ModuleList([copy.deepcopy(module) for i in range(N)])
841
+
842
+
843
+ def _get_activation_fn(activation):
844
+ if activation == "relu":
845
+ return F.relu
846
+ elif activation == "gelu":
847
+ return F.gelu
848
+
849
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
850
+
851
+
852
+ class PositionalEncoding(nn.Module):
853
+ r"""Inject some information about the relative or absolute position of the tokens
854
+ in the sequence. The positional encodings have the same dimension as
855
+ the embeddings, so that the two can be summed. Here, we use sine and cosine
856
+ functions of different frequencies.
857
+ .. math::
858
+ \text{PosEncoder}(pos, 2i) = \sin(pos/10000^{2i/d_{model}})
859
+ \text{PosEncoder}(pos, 2i+1) = \cos(pos/10000^{2i/d_{model}})
860
+ \text{where pos is the word position and i is the embed idx}
861
+ Args:
862
+ d_model: the embed dim (required).
863
+ dropout: the dropout value (default=0.1).
864
+ max_len: the max. length of the incoming sequence (default=5000).
865
+ Examples:
866
+ >>> pos_encoder = PositionalEncoding(d_model)
867
+ """
868
+
869
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
870
+ super(PositionalEncoding, self).__init__()
871
+ self.dropout = nn.Dropout(p=dropout)
872
+
873
+ pe = torch.zeros(max_len, d_model)
874
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
875
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
876
+ pe[:, 0::2] = torch.sin(position * div_term)
877
+ pe[:, 1::2] = torch.cos(position * div_term)
878
+ pe = pe.unsqueeze(0).transpose(0, 1)
879
+ self.register_buffer('pe', pe)
880
+
881
+ def forward(self, x):
882
+ r"""Inputs of forward function
883
+ Args:
884
+ x: the sequence fed to the positional encoder model (required).
885
+ Shape:
886
+ x: [sequence length, batch size, embed dim]
887
+ output: [sequence length, batch size, embed dim]
888
+ Examples:
889
+ >>> output = pos_encoder(x)
890
+ """
891
+
892
+ x = x + self.pe[:x.size(0), :]
893
+ return self.dropout(x)
894
+
895
+
896
+ if __name__ == '__main__':
897
+ transformer_model = Transformer(nhead=16, num_encoder_layers=12)
898
+ src = torch.rand((10, 32, 512))
899
+ tgt = torch.rand((20, 32, 512))
900
+ out = transformer_model(src, tgt)
901
+ print(out)
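As a quick sanity check of the classes defined in this file, here is a hedged sketch (illustrative sizes, not part of the commit) of driving the custom Transformer with a causal target mask from generate_square_subsequent_mask:

```python
import torch

# Illustrative sizes only: S=10 source steps, T=20 target steps, N=32 batch, E=512 features.
model = Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
src = torch.rand(10, 32, 512)                                   # (S, N, E)
tgt = torch.rand(20, 32, 512)                                   # (T, N, E)
tgt_mask = model.generate_square_subsequent_mask(tgt.size(0))   # (T, T) float mask, -inf above the diagonal
out = model(src, tgt, tgt_mask=tgt_mask)                        # (T, N, E)
print(out.shape)
```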
notebooks/dataset-text.ipynb ADDED
@@ -0,0 +1,159 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "os.chdir('..')\n",
11
+ "from dataset import *\n",
12
+ "torch.set_printoptions(sci_mode=False)"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "markdown",
17
+ "metadata": {},
18
+ "source": [
19
+ "# Construct dataset"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "data = TextDataset('data/Vocabulary_train_v2.csv', is_training=False, smooth_label=True, smooth_factor=0.1)"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "data = DataBunch.create(train_ds=data, valid_ds=None, bs=6)"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "x, y = data.one_batch(); x, y"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": null,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "x[0].shape, x[1].shape"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "y[0].shape, y[1].shape"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": null,
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "x[0].argmax(-1) - y[0].argmax(-1)"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "x[0].argmax(-1)"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "y[0].argmax(-1)"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "x[0][0,0]"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "markdown",
105
+ "metadata": {},
106
+ "source": [
107
+ "# test SpellingMutation"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": [
116
+ "probs = {'pn0': 0., 'pn1': 0., 'pn2': 0., 'pt0': 1.0, 'pt1': 1.0}\n",
117
+ "charset = CharsetMapper('data/charset_36.txt')\n",
118
+ "sm = SpellingMutation(charset=charset, **probs)"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "sm('*a-aa')"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "metadata": {},
134
+ "outputs": [],
135
+ "source": []
136
+ }
137
+ ],
138
+ "metadata": {
139
+ "kernelspec": {
140
+ "display_name": "Python 3",
141
+ "language": "python",
142
+ "name": "python3"
143
+ },
144
+ "language_info": {
145
+ "codemirror_mode": {
146
+ "name": "ipython",
147
+ "version": 3
148
+ },
149
+ "file_extension": ".py",
150
+ "mimetype": "text/x-python",
151
+ "name": "python",
152
+ "nbconvert_exporter": "python",
153
+ "pygments_lexer": "ipython3",
154
+ "version": "3.7.4"
155
+ }
156
+ },
157
+ "nbformat": 4,
158
+ "nbformat_minor": 2
159
+ }
notebooks/dataset.ipynb ADDED
@@ -0,0 +1,298 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "os.chdir('..')\n",
11
+ "from dataset import *"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {
18
+ "scrolled": false
19
+ },
20
+ "outputs": [],
21
+ "source": [
22
+ "import logging\n",
23
+ "from torchvision.transforms import ToPILImage\n",
24
+ "from torchvision.utils import make_grid\n",
25
+ "from IPython.display import display\n",
26
+ "from torch.utils.data import ConcatDataset\n",
27
+ "charset = CharsetMapper('data/charset_36.txt')"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "def show_all(dl, iter_size=None):\n",
37
+ " if iter_size is None: iter_size = len(dl)\n",
38
+ " for i, item in enumerate(dl):\n",
39
+ " if i >= iter_size:\n",
40
+ " break\n",
41
+ " image = item[0]\n",
42
+ " label = item[1][0]\n",
43
+ " length = item[1][1]\n",
44
+ " print(f'iter {i}:', [charset.get_text(label[j][0: length[j]].argmax(-1), padding=False) for j in range(bs)])\n",
45
+ " display(ToPILImage()(make_grid(item[0].cpu())))"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "markdown",
50
+ "metadata": {},
51
+ "source": [
52
+ "# Construct dataset"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "data1 = ImageDataset('data/training/ST', is_training=True);data1 # is_training"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "metadata": {
68
+ "scrolled": true
69
+ },
70
+ "outputs": [],
71
+ "source": [
72
+ "bs=64\n",
73
+ "data2 = ImageDataBunch.create(train_ds=data1, valid_ds=None, bs=bs, num_workers=1);data2"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "#data3 = data2.normalize(imagenet_stats);data3\n",
83
+ "data3 = data2"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": [
92
+ "show_all(data3.train_dl, 4)"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "markdown",
97
+ "metadata": {},
98
+ "source": [
99
+ "# Add dataset"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": null,
105
+ "metadata": {},
106
+ "outputs": [],
107
+ "source": [
108
+ "kwargs = {'data_aug': False, 'is_training': False}"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": null,
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "data1 = ImageDataset('data/evaluation/IIIT5k_3000', **kwargs);data1"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "metadata": {},
124
+ "outputs": [],
125
+ "source": [
126
+ "data2 = ImageDataset('data/evaluation/SVT', **kwargs);data2"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": null,
132
+ "metadata": {},
133
+ "outputs": [],
134
+ "source": [
135
+ "data3 = ConcatDataset([data1, data2])"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "bs=64\n",
145
+ "data4 = ImageDataBunch.create(train_ds=data1, valid_ds=data3, bs=bs, num_workers=1);data4"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "len(data4.train_dl), len(data4.valid_dl)"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": null,
160
+ "metadata": {},
161
+ "outputs": [],
162
+ "source": [
163
+ "show_all(data4.train_dl, 4)"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "markdown",
168
+ "metadata": {},
169
+ "source": [
170
+ "# TEST"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": null,
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "len(data4.valid_dl)"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "metadata": {},
186
+ "outputs": [],
187
+ "source": [
188
+ "import time\n",
189
+ "niter = 1000\n",
190
+ "start = time.time()\n",
191
+ "for i, item in enumerate(progress_bar(data4.valid_dl)):\n",
192
+ " if i % niter == 0 and i > 0:\n",
193
+ " print(i, (time.time() - start) / niter)\n",
194
+ " start = time.time()"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "code",
199
+ "execution_count": null,
200
+ "metadata": {
201
+ "scrolled": true
202
+ },
203
+ "outputs": [],
204
+ "source": [
205
+ "num = 20\n",
206
+ "index = 6\n",
207
+ "plt.figure(figsize=(20, 10))\n",
208
+ "for i in range(num):\n",
209
+ " plt.subplot(num // 4, 4, i+1)\n",
210
+ " plt.imshow(data4.train_ds[i][0].data.numpy().transpose(1,2,0))"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "code",
215
+ "execution_count": null,
216
+ "metadata": {},
217
+ "outputs": [],
218
+ "source": [
219
+ "def show(path, image_key):\n",
220
+ " with lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False).begin(write=False) as txn:\n",
221
+ " imgbuf = txn.get(image_key.encode()) # image\n",
222
+ " buf = six.BytesIO()\n",
223
+ " buf.write(imgbuf)\n",
224
+ " buf.seek(0)\n",
225
+ " with warnings.catch_warnings():\n",
226
+ " warnings.simplefilter(\"ignore\", UserWarning) # EXIF warning from TiffPlugin\n",
227
+ " x = PIL.Image.open(buf).convert('RGB')\n",
228
+ " print(x.size)\n",
229
+ " plt.imshow(x)"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "metadata": {},
236
+ "outputs": [],
237
+ "source": [
238
+ "image_key = 'image-003118258'\n",
239
+ "image_key = 'image-002780217'\n",
240
+ "image_key = 'image-002780218'\n",
241
+ "path = 'data/CVPR2016'\n",
242
+ "show(path, image_key)"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "code",
247
+ "execution_count": null,
248
+ "metadata": {},
249
+ "outputs": [],
250
+ "source": [
251
+ "image_key = 'image-004668347'\n",
252
+ "image_key = 'image-006128516'\n",
253
+ "path = 'data/NIPS2014'\n",
254
+ "show(path, image_key)"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "image_key = 'image-004668347'\n",
264
+ "image_key = 'image-000002420'\n",
265
+ "path = 'data/IIIT5K_3000'\n",
266
+ "show(path, image_key)"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": null,
272
+ "metadata": {},
273
+ "outputs": [],
274
+ "source": []
275
+ }
276
+ ],
277
+ "metadata": {
278
+ "kernelspec": {
279
+ "display_name": "Python 3",
280
+ "language": "python",
281
+ "name": "python3"
282
+ },
283
+ "language_info": {
284
+ "codemirror_mode": {
285
+ "name": "ipython",
286
+ "version": 3
287
+ },
288
+ "file_extension": ".py",
289
+ "mimetype": "text/x-python",
290
+ "name": "python",
291
+ "nbconvert_exporter": "python",
292
+ "pygments_lexer": "ipython3",
293
+ "version": "3.7.4"
294
+ }
295
+ },
296
+ "nbformat": 4,
297
+ "nbformat_minor": 2
298
+ }
notebooks/prepare_wikitext103.ipynb ADDED
@@ -0,0 +1,468 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 82841986 is_char and is_digit"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "# 82075350 regex non-ascii and non-digit"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "markdown",
19
+ "metadata": {},
20
+ "source": [
21
+ "## 86460763 left"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 1,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "import os\n",
31
+ "import random\n",
32
+ "import re\n",
33
+ "import pandas as pd"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 2,
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "max_length = 25\n",
43
+ "min_length = 1\n",
44
+ "root = '../data'\n",
45
+ "charset = 'abcdefghijklmnopqrstuvwxyz'\n",
46
+ "digits = '0123456789'"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 3,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "def is_char(text, ratio=0.5):\n",
56
+ " text = text.lower()\n",
57
+ " length = max(len(text), 1)\n",
58
+ " char_num = sum([t in charset for t in text])\n",
59
+ " if char_num < min_length: return False\n",
60
+ " if char_num / length < ratio: return False\n",
61
+ " return True\n",
62
+ "\n",
63
+ "def is_digit(text, ratio=0.5):\n",
64
+ " length = max(len(text), 1)\n",
65
+ " digit_num = sum([t in digits for t in text])\n",
66
+ " if digit_num / length < ratio: return False\n",
67
+ " return True"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "markdown",
72
+ "metadata": {},
73
+ "source": [
74
+ "# generate training dataset"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 4,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "with open('/tmp/wikitext-103/wiki.train.tokens', 'r') as file:\n",
84
+ " lines = file.readlines()"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": 5,
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "inp, gt = [], []\n",
94
+ "for line in lines:\n",
95
+ " token = line.lower().split()\n",
96
+ " for text in token:\n",
97
+ " text = re.sub('[^0-9a-zA-Z]+', '', text)\n",
98
+ " if len(text) < min_length:\n",
99
+ " # print('short-text', text)\n",
100
+ " continue\n",
101
+ " if len(text) > max_length:\n",
102
+ " # print('long-text', text)\n",
103
+ " continue\n",
104
+ " inp.append(text)\n",
105
+ " gt.append(text)"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": 6,
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "train_voc = os.path.join(root, 'WikiText-103.csv')\n",
115
+ "pd.DataFrame({'inp':inp, 'gt':gt}).to_csv(train_voc, index=None, sep='\\t')"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": 7,
121
+ "metadata": {},
122
+ "outputs": [
123
+ {
124
+ "data": {
125
+ "text/plain": [
126
+ "86460763"
127
+ ]
128
+ },
129
+ "execution_count": 7,
130
+ "metadata": {},
131
+ "output_type": "execute_result"
132
+ }
133
+ ],
134
+ "source": [
135
+ "len(inp)"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": 8,
141
+ "metadata": {},
142
+ "outputs": [
143
+ {
144
+ "data": {
145
+ "text/plain": [
146
+ "['valkyria',\n",
147
+ " 'chronicles',\n",
148
+ " 'iii',\n",
149
+ " 'senj',\n",
150
+ " 'no',\n",
151
+ " 'valkyria',\n",
152
+ " '3',\n",
153
+ " 'unk',\n",
154
+ " 'chronicles',\n",
155
+ " 'japanese',\n",
156
+ " '3',\n",
157
+ " 'lit',\n",
158
+ " 'valkyria',\n",
159
+ " 'of',\n",
160
+ " 'the',\n",
161
+ " 'battlefield',\n",
162
+ " '3',\n",
163
+ " 'commonly',\n",
164
+ " 'referred',\n",
165
+ " 'to',\n",
166
+ " 'as',\n",
167
+ " 'valkyria',\n",
168
+ " 'chronicles',\n",
169
+ " 'iii',\n",
170
+ " 'outside',\n",
171
+ " 'japan',\n",
172
+ " 'is',\n",
173
+ " 'a',\n",
174
+ " 'tactical',\n",
175
+ " 'role',\n",
176
+ " 'playing',\n",
177
+ " 'video',\n",
178
+ " 'game',\n",
179
+ " 'developed',\n",
180
+ " 'by',\n",
181
+ " 'sega',\n",
182
+ " 'and',\n",
183
+ " 'mediavision',\n",
184
+ " 'for',\n",
185
+ " 'the',\n",
186
+ " 'playstation',\n",
187
+ " 'portable',\n",
188
+ " 'released',\n",
189
+ " 'in',\n",
190
+ " 'january',\n",
191
+ " '2011',\n",
192
+ " 'in',\n",
193
+ " 'japan',\n",
194
+ " 'it',\n",
195
+ " 'is',\n",
196
+ " 'the',\n",
197
+ " 'third',\n",
198
+ " 'game',\n",
199
+ " 'in',\n",
200
+ " 'the',\n",
201
+ " 'valkyria',\n",
202
+ " 'series',\n",
203
+ " 'employing',\n",
204
+ " 'the',\n",
205
+ " 'same',\n",
206
+ " 'fusion',\n",
207
+ " 'of',\n",
208
+ " 'tactical',\n",
209
+ " 'and',\n",
210
+ " 'real',\n",
211
+ " 'time',\n",
212
+ " 'gameplay',\n",
213
+ " 'as',\n",
214
+ " 'its',\n",
215
+ " 'predecessors',\n",
216
+ " 'the',\n",
217
+ " 'story',\n",
218
+ " 'runs',\n",
219
+ " 'parallel',\n",
220
+ " 'to',\n",
221
+ " 'the',\n",
222
+ " 'first',\n",
223
+ " 'game',\n",
224
+ " 'and',\n",
225
+ " 'follows',\n",
226
+ " 'the',\n",
227
+ " 'nameless',\n",
228
+ " 'a',\n",
229
+ " 'penal',\n",
230
+ " 'military',\n",
231
+ " 'unit',\n",
232
+ " 'serving',\n",
233
+ " 'the',\n",
234
+ " 'nation',\n",
235
+ " 'of',\n",
236
+ " 'gallia',\n",
237
+ " 'during',\n",
238
+ " 'the',\n",
239
+ " 'second',\n",
240
+ " 'europan',\n",
241
+ " 'war',\n",
242
+ " 'who',\n",
243
+ " 'perform',\n",
244
+ " 'secret',\n",
245
+ " 'black']"
246
+ ]
247
+ },
248
+ "execution_count": 8,
249
+ "metadata": {},
250
+ "output_type": "execute_result"
251
+ }
252
+ ],
253
+ "source": [
254
+ "inp[:100]"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "markdown",
259
+ "metadata": {},
260
+ "source": [
261
+ "# generate evaluation dataset"
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "execution_count": 9,
267
+ "metadata": {},
268
+ "outputs": [],
269
+ "source": [
270
+ "def disturb(word, degree, p=0.3):\n",
271
+ " if len(word) // 2 < degree: return word\n",
272
+ " if is_digit(word): return word\n",
273
+ " if random.random() < p: return word\n",
274
+ " else:\n",
275
+ " index = list(range(len(word)))\n",
276
+ " random.shuffle(index)\n",
277
+ " index = index[:degree]\n",
278
+ " new_word = []\n",
279
+ " for i in range(len(word)):\n",
280
+ " if i not in index: \n",
281
+ " new_word.append(word[i])\n",
282
+ " continue\n",
283
+ " if (word[i] not in charset) and (word[i] not in digits):\n",
284
+ " # special token\n",
285
+ " new_word.append(word[i])\n",
286
+ " continue\n",
287
+ " op = random.random()\n",
288
+ " if op < 0.1: # add\n",
289
+ " new_word.append(random.choice(charset))\n",
290
+ " new_word.append(word[i])\n",
291
+ " elif op < 0.2: continue # remove\n",
292
+ " else: new_word.append(random.choice(charset)) # replace\n",
293
+ " return ''.join(new_word)"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "code",
298
+ "execution_count": 10,
299
+ "metadata": {},
300
+ "outputs": [],
301
+ "source": [
302
+ "lines = inp\n",
303
+ "degree = 1\n",
304
+ "keep_num = 50000\n",
305
+ "\n",
306
+ "random.shuffle(lines)\n",
307
+ "part_lines = lines[:keep_num]\n",
308
+ "inp, gt = [], []\n",
309
+ "\n",
310
+ "for w in part_lines:\n",
311
+ " w = w.strip().lower()\n",
312
+ " new_w = disturb(w, degree)\n",
313
+ " inp.append(new_w)\n",
314
+ " gt.append(w)\n",
315
+ " \n",
316
+ "eval_voc = os.path.join(root, f'WikiText-103_eval_d{degree}.csv')\n",
317
+ "pd.DataFrame({'inp':inp, 'gt':gt}).to_csv(eval_voc, index=None, sep='\\t')"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": 11,
323
+ "metadata": {},
324
+ "outputs": [
325
+ {
326
+ "data": {
327
+ "text/plain": [
328
+ "[('high', 'high'),\n",
329
+ " ('vctoria', 'victoria'),\n",
330
+ " ('mains', 'mains'),\n",
331
+ " ('bi', 'by'),\n",
332
+ " ('13', '13'),\n",
333
+ " ('ticnet', 'ticket'),\n",
334
+ " ('basil', 'basic'),\n",
335
+ " ('cut', 'cut'),\n",
336
+ " ('aqarky', 'anarky'),\n",
337
+ " ('the', 'the'),\n",
338
+ " ('tqe', 'the'),\n",
339
+ " ('oc', 'of'),\n",
340
+ " ('diwpersal', 'dispersal'),\n",
341
+ " ('traffic', 'traffic'),\n",
342
+ " ('in', 'in'),\n",
343
+ " ('the', 'the'),\n",
344
+ " ('ti', 'to'),\n",
345
+ " ('professionalms', 'professionals'),\n",
346
+ " ('747', '747'),\n",
347
+ " ('in', 'in'),\n",
348
+ " ('and', 'and'),\n",
349
+ " ('exezutive', 'executive'),\n",
350
+ " ('n400', 'n400'),\n",
351
+ " ('yusic', 'music'),\n",
352
+ " ('s', 's'),\n",
353
+ " ('henri', 'henry'),\n",
354
+ " ('heard', 'heard'),\n",
355
+ " ('thousand', 'thousand'),\n",
356
+ " ('to', 'to'),\n",
357
+ " ('arhy', 'army'),\n",
358
+ " ('td', 'to'),\n",
359
+ " ('a', 'a'),\n",
360
+ " ('oall', 'hall'),\n",
361
+ " ('qind', 'kind'),\n",
362
+ " ('od', 'on'),\n",
363
+ " ('samfria', 'samaria'),\n",
364
+ " ('driveway', 'driveway'),\n",
365
+ " ('which', 'which'),\n",
366
+ " ('wotk', 'work'),\n",
367
+ " ('ak', 'as'),\n",
368
+ " ('persona', 'persona'),\n",
369
+ " ('s', 's'),\n",
370
+ " ('melbourne', 'melbourne'),\n",
371
+ " ('apong', 'along'),\n",
372
+ " ('fas', 'was'),\n",
373
+ " ('thea', 'then'),\n",
374
+ " ('permcy', 'percy'),\n",
375
+ " ('nnd', 'and'),\n",
376
+ " ('alan', 'alan'),\n",
377
+ " ('13', '13'),\n",
378
+ " ('matteos', 'matters'),\n",
379
+ " ('against', 'against'),\n",
380
+ " ('nefion', 'nexion'),\n",
381
+ " ('held', 'held'),\n",
382
+ " ('negative', 'negative'),\n",
383
+ " ('gogd', 'good'),\n",
384
+ " ('the', 'the'),\n",
385
+ " ('thd', 'the'),\n",
386
+ " ('groening', 'groening'),\n",
387
+ " ('tqe', 'the'),\n",
388
+ " ('cwould', 'would'),\n",
389
+ " ('fb', 'ft'),\n",
390
+ " ('uniten', 'united'),\n",
391
+ " ('kone', 'one'),\n",
392
+ " ('thiy', 'this'),\n",
393
+ " ('lanren', 'lauren'),\n",
394
+ " ('s', 's'),\n",
395
+ " ('thhe', 'the'),\n",
396
+ " ('is', 'is'),\n",
397
+ " ('modep', 'model'),\n",
398
+ " ('weird', 'weird'),\n",
399
+ " ('angwer', 'answer'),\n",
400
+ " ('imprisxnment', 'imprisonment'),\n",
401
+ " ('marpery', 'margery'),\n",
402
+ " ('eventuanly', 'eventually'),\n",
403
+ " ('in', 'in'),\n",
404
+ " ('donnoa', 'donna'),\n",
405
+ " ('ik', 'it'),\n",
406
+ " ('reached', 'reached'),\n",
407
+ " ('at', 'at'),\n",
408
+ " ('excxted', 'excited'),\n",
409
+ " ('ws', 'was'),\n",
410
+ " ('raes', 'rates'),\n",
411
+ " ('the', 'the'),\n",
412
+ " ('firsq', 'first'),\n",
413
+ " ('concluyed', 'concluded'),\n",
414
+ " ('recdorded', 'recorded'),\n",
415
+ " ('fhe', 'the'),\n",
416
+ " ('uegiment', 'regiment'),\n",
417
+ " ('a', 'a'),\n",
418
+ " ('glanes', 'planes'),\n",
419
+ " ('conyrol', 'control'),\n",
420
+ " ('thr', 'the'),\n",
421
+ " ('arrext', 'arrest'),\n",
422
+ " ('bth', 'both'),\n",
423
+ " ('forward', 'forward'),\n",
424
+ " ('allowdd', 'allowed'),\n",
425
+ " ('revealed', 'revealed'),\n",
426
+ " ('mayagement', 'management'),\n",
427
+ " ('normal', 'normal')]"
428
+ ]
429
+ },
430
+ "execution_count": 11,
431
+ "metadata": {},
432
+ "output_type": "execute_result"
433
+ }
434
+ ],
435
+ "source": [
436
+ "list(zip(inp, gt))[:100]"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": null,
442
+ "metadata": {},
443
+ "outputs": [],
444
+ "source": []
445
+ }
446
+ ],
447
+ "metadata": {
448
+ "kernelspec": {
449
+ "display_name": "Python 3",
450
+ "language": "python",
451
+ "name": "python3"
452
+ },
453
+ "language_info": {
454
+ "codemirror_mode": {
455
+ "name": "ipython",
456
+ "version": 3
457
+ },
458
+ "file_extension": ".py",
459
+ "mimetype": "text/x-python",
460
+ "name": "python",
461
+ "nbconvert_exporter": "python",
462
+ "pygments_lexer": "ipython3",
463
+ "version": "3.7.4"
464
+ }
465
+ },
466
+ "nbformat": 4,
467
+ "nbformat_minor": 4
468
+ }
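For readers following the notebook above, a hedged sketch of what disturb() produces on a few words; it assumes the charset, digits, and disturb definitions from the cells above are in scope, and the words are illustrative:

```python
import random

random.seed(0)  # repeatable illustration only
for word in ['recognition', 'valkyria', '2011']:
    print(word, '->', disturb(word, degree=1))
# Digit-only tokens such as '2011' are returned unchanged by the is_digit guard;
# the others are either kept (with probability p) or get roughly one character
# added, removed, or replaced.
```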
notebooks/transforms.ipynb ADDED
The diff for this file is too large to render. See raw diff
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ torch==1.1.0
2
+ torchvision==0.3.0
3
+ fastai==1.0.60
4
+ LMDB
5
+ Pillow
6
+ opencv-python
7
+ tensorboardX
8
+ PyYAML
9
+ gdown
tools/create_lmdb_dataset.py ADDED
@@ -0,0 +1,87 @@
1
+ """ a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """
2
+
3
+ import fire
4
+ import os
5
+ import lmdb
6
+ import cv2
7
+
8
+ import numpy as np
9
+
10
+
11
+ def checkImageIsValid(imageBin):
12
+ if imageBin is None:
13
+ return False
14
+ imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
15
+ img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
16
+ imgH, imgW = img.shape[0], img.shape[1]
17
+ if imgH * imgW == 0:
18
+ return False
19
+ return True
20
+
21
+
22
+ def writeCache(env, cache):
23
+ with env.begin(write=True) as txn:
24
+ for k, v in cache.items():
25
+ txn.put(k, v)
26
+
27
+
28
+ def createDataset(inputPath, gtFile, outputPath, checkValid=True):
29
+ """
30
+ Create LMDB dataset for training and evaluation.
31
+ ARGS:
32
+ inputPath : input folder path where starts imagePath
33
+ outputPath : LMDB output path
34
+ gtFile : list of image path and label
35
+ checkValid : if true, check the validity of every image
36
+ """
37
+ os.makedirs(outputPath, exist_ok=True)
38
+ env = lmdb.open(outputPath, map_size=1099511627776)
39
+ cache = {}
40
+ cnt = 1
41
+
42
+ with open(gtFile, 'r', encoding='utf-8') as data:
43
+ datalist = data.readlines()
44
+
45
+ nSamples = len(datalist)
46
+ for i in range(nSamples):
47
+ imagePath, label = datalist[i].strip('\n').split('\t')
48
+ imagePath = os.path.join(inputPath, imagePath)
49
+
50
+ # # only use alphanumeric data
51
+ # if re.search('[^a-zA-Z0-9]', label):
52
+ # continue
53
+
54
+ if not os.path.exists(imagePath):
55
+ print('%s does not exist' % imagePath)
56
+ continue
57
+ with open(imagePath, 'rb') as f:
58
+ imageBin = f.read()
59
+ if checkValid:
60
+ try:
61
+ if not checkImageIsValid(imageBin):
62
+ print('%s is not a valid image' % imagePath)
63
+ continue
64
+ except:
65
+ print('error occurred', i)
66
+ with open(outputPath + '/error_image_log.txt', 'a') as log:
67
+ log.write('%s-th image data caused an error\n' % str(i))
68
+ continue
69
+
70
+ imageKey = 'image-%09d'.encode() % cnt
71
+ labelKey = 'label-%09d'.encode() % cnt
72
+ cache[imageKey] = imageBin
73
+ cache[labelKey] = label.encode()
74
+
75
+ if cnt % 1000 == 0:
76
+ writeCache(env, cache)
77
+ cache = {}
78
+ print('Written %d / %d' % (cnt, nSamples))
79
+ cnt += 1
80
+ nSamples = cnt-1
81
+ cache['num-samples'.encode()] = str(nSamples).encode()
82
+ writeCache(env, cache)
83
+ print('Created dataset with %d samples' % nSamples)
84
+
85
+
86
+ if __name__ == '__main__':
87
+ fire.Fire(createDataset)
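To verify what createDataset wrote, a minimal read-back sketch (not part of the commit) using the same key layout as the writer above ('num-samples', 'image-%09d', 'label-%09d'); the function name and 1-based index convention are illustrative:

```python
import io
import lmdb
from PIL import Image

def read_sample(lmdb_path, index=1):
    # Keys follow the writer above; index is 1-based.
    env = lmdb.open(lmdb_path, readonly=True, lock=False, readahead=False, meminit=False)
    with env.begin(write=False) as txn:
        n_samples = int(txn.get('num-samples'.encode()))
        assert 1 <= index <= n_samples, 'index out of range'
        image_bin = txn.get(('image-%09d' % index).encode())
        label = txn.get(('label-%09d' % index).encode()).decode()
    img = Image.open(io.BytesIO(image_bin)).convert('RGB')
    return img, label
```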
tools/crop_by_word_bb_syn90k.py ADDED
@@ -0,0 +1,153 @@
1
+ # Crop by word bounding box
2
+ # Locate script with gt.mat
3
+ # $ python crop_by_word_bb.py
4
+
5
+ import os
6
+ import re
7
+ import cv2
8
+ import scipy.io as sio
9
+ from itertools import chain
10
+ import numpy as np
11
+ import math
12
+
13
+ mat_contents = sio.loadmat('gt.mat')
14
+
15
+ image_names = mat_contents['imnames'][0]
16
+ cropped_indx = 0
17
+ start_img_indx = 0
18
+ gt_file = open('gt_oabc.txt', 'a')
19
+ err_file = open('err_oabc.txt', 'a')
20
+
21
+ for img_indx in range(start_img_indx, len(image_names)):
22
+
23
+
24
+ # Get image name
25
+ image_name_new = image_names[img_indx][0]
26
+ # print(image_name_new)
27
+ image_name = '/home/yxwang/pytorch/dataset/SynthText/img/'+ image_name_new
28
+ # print('IMAGE : {}.{}'.format(img_indx, image_name))
29
+ print('evaluating {} image'.format(img_indx), end='\r')
30
+ # Get text in image
31
+ txt = mat_contents['txt'][0][img_indx]
32
+ txt = [re.split(' \n|\n |\n| ', t.strip()) for t in txt]
33
+ txt = list(chain(*txt))
34
+ txt = [t for t in txt if len(t) > 0 ]
35
+ # print(txt) # ['Lines:', 'I', 'lost', 'Kevin', 'will', 'line', 'and', 'and', 'the', '(and', 'the', 'out', 'you', "don't", 'pkg']
36
+ # assert 1<0
37
+
38
+ # Open image
39
+ #img = Image.open(image_name)
40
+ img = cv2.imread(image_name, cv2.IMREAD_COLOR)
41
+ img_height, img_width, _ = img.shape
42
+
43
+ # Validation
44
+ if len(np.shape(mat_contents['wordBB'][0][img_indx])) == 2:
45
+ wordBBlen = 1
46
+ else:
47
+ wordBBlen = mat_contents['wordBB'][0][img_indx].shape[-1]
48
+
49
+ if wordBBlen == len(txt):
50
+ # Crop image and save
51
+ for word_indx in range(len(txt)):
52
+ # print('txt--',txt)
53
+ txt_temp = txt[word_indx]
54
+ len_now = len(txt_temp)
55
+ # txt_temp = re.sub('[^0-9a-zA-Z]+', '', txt_temp)
56
+ # print('txt_temp-1-',txt_temp)
57
+ txt_temp = re.sub('[^a-zA-Z]+', '', txt_temp)
58
+ # print('txt_temp-2-',txt_temp)
59
+ if len_now - len(txt_temp) != 0:
60
+ print('txt_temp-2-', txt_temp)
61
+
62
+ if len(np.shape(mat_contents['wordBB'][0][img_indx])) == 2: # only one word (2,4)
63
+ wordBB = mat_contents['wordBB'][0][img_indx]
64
+ else: # many words (2,4,num_words)
65
+ wordBB = mat_contents['wordBB'][0][img_indx][:, :, word_indx]
66
+
67
+ if np.shape(wordBB) != (2, 4):
68
+ err_log = 'malformed box index: {}\t{}\t{}\n'.format(image_name, txt[word_indx], wordBB)
69
+ err_file.write(err_log)
70
+ # print(err_log)
71
+ continue
72
+
73
+ pts1 = np.float32([[wordBB[0][0], wordBB[1][0]],
74
+ [wordBB[0][3], wordBB[1][3]],
75
+ [wordBB[0][1], wordBB[1][1]],
76
+ [wordBB[0][2], wordBB[1][2]]])
77
+ height = math.sqrt((wordBB[0][0] - wordBB[0][3])**2 + (wordBB[1][0] - wordBB[1][3])**2)
78
+ width = math.sqrt((wordBB[0][0] - wordBB[0][1])**2 + (wordBB[1][0] - wordBB[1][1])**2)
79
+
80
+ # Coord validation check
81
+ if (height * width) <= 0:
82
+ err_log = 'empty file : {}\t{}\t{}\n'.format(image_name, txt[word_indx], wordBB)
83
+ err_file.write(err_log)
84
+ # print(err_log)
85
+ continue
86
+ elif (height * width) > (img_height * img_width):
87
+ err_log = 'too big box : {}\t{}\t{}\n'.format(image_name, txt[word_indx], wordBB)
88
+ err_file.write(err_log)
89
+ # print(err_log)
90
+ continue
91
+ else:
92
+ valid = True
93
+ for i in range(2):
94
+ for j in range(4):
95
+ if wordBB[i][j] < 0 or wordBB[i][j] > img.shape[1 - i]:
96
+ valid = False
97
+ break
98
+ if not valid:
99
+ break
100
+ if not valid:
101
+ err_log = 'invalid coord : {}\t{}\t{}\t{}\t{}\n'.format(
102
+ image_name, txt[word_indx], wordBB, (width, height), (img_width, img_height))
103
+ err_file.write(err_log)
104
+ # print(err_log)
105
+ continue
106
+
107
+ pts2 = np.float32([[0, 0],
108
+ [0, height],
109
+ [width, 0],
110
+ [width, height]])
111
+
112
+ x_min = np.int(round(min(wordBB[0][0], wordBB[0][1], wordBB[0][2], wordBB[0][3])))
113
+ x_max = np.int(round(max(wordBB[0][0], wordBB[0][1], wordBB[0][2], wordBB[0][3])))
114
+ y_min = np.int(round(min(wordBB[1][0], wordBB[1][1], wordBB[1][2], wordBB[1][3])))
115
+ y_max = np.int(round(max(wordBB[1][0], wordBB[1][1], wordBB[1][2], wordBB[1][3])))
116
+ # print(x_min, x_max, y_min, y_max)
117
+ # print(img.shape)
118
+ # assert 1<0
119
+ if len(img.shape) == 3:
120
+ img_cropped = img[ y_min:y_max:1, x_min:x_max:1, :]
121
+ else:
122
+ img_cropped = img[ y_min:y_max:1, x_min:x_max:1]
123
+ dir_name = '/home/yxwang/pytorch/dataset/SynthText/cropped-oabc/{}'.format(image_name_new.split('/')[0])
124
+ # print('dir_name--',dir_name)
125
+ if not os.path.exists(dir_name):
126
+ os.mkdir(dir_name)
127
+ cropped_file_name = "{}/{}_{}_{}.jpg".format(dir_name, cropped_indx,
128
+ image_name.split('/')[-1][:-len('.jpg')], word_indx)
129
+ # print('cropped_file_name--',cropped_file_name)
130
+ # print('img_cropped--',img_cropped.shape)
131
+ if img_cropped.shape[0] == 0 or img_cropped.shape[1] == 0:
132
+ err_log = 'word_box_mismatch : {}\t{}\t{}\n'.format(image_name, mat_contents['txt'][0][
133
+ img_indx], mat_contents['wordBB'][0][img_indx])
134
+ err_file.write(err_log)
135
+ # print(err_log)
136
+ continue
137
+ # print('img_cropped--',img_cropped)
138
+
139
+ # img_cropped.save(cropped_file_name)
140
+ cv2.imwrite(cropped_file_name, img_cropped)
141
+ cropped_indx += 1
142
+ gt_file.write('%s\t%s\n' % (cropped_file_name, txt[word_indx]))
143
+
144
+ # if cropped_indx>10:
145
+ # assert 1<0
146
+ # assert 1 < 0
147
+ else:
148
+ err_log = 'word_box_mismatch : {}\t{}\t{}\n'.format(image_name, mat_contents['txt'][0][
149
+ img_indx], mat_contents['wordBB'][0][img_indx])
150
+ err_file.write(err_log)
151
+ # print(err_log)
152
+ gt_file.close()
153
+ err_file.close()
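The script above computes perspective source/destination points (pts1/pts2) but ultimately crops an axis-aligned box. A hedged sketch of the alternative, rectifying each word with the same point ordering, is shown below; it is illustrative and not what the commit does:

```python
import cv2
import numpy as np

def rectify_word(img, wordBB):
    # Point order matches the script: top-left, bottom-left, top-right, bottom-right.
    pts1 = np.float32([[wordBB[0][0], wordBB[1][0]],
                       [wordBB[0][3], wordBB[1][3]],
                       [wordBB[0][1], wordBB[1][1]],
                       [wordBB[0][2], wordBB[1][2]]])
    height = int(round(np.hypot(wordBB[0][0] - wordBB[0][3], wordBB[1][0] - wordBB[1][3])))
    width = int(round(np.hypot(wordBB[0][0] - wordBB[0][1], wordBB[1][0] - wordBB[1][1])))
    pts2 = np.float32([[0, 0], [0, height], [width, 0], [width, height]])
    M = cv2.getPerspectiveTransform(pts1, pts2)
    return cv2.warpPerspective(img, M, (width, height))
```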
transforms.py ADDED
@@ -0,0 +1,329 @@
1
+ import math
2
+ import numbers
3
+ import random
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+ from torchvision import transforms
9
+ from torchvision.transforms import Compose
10
+
11
+
12
+ def sample_asym(magnitude, size=None):
13
+ return np.random.beta(1, 4, size) * magnitude
14
+
15
+ def sample_sym(magnitude, size=None):
16
+ return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
17
+
18
+ def sample_uniform(low, high, size=None):
19
+ return np.random.uniform(low, high, size=size)
20
+
21
+ def get_interpolation(type='random'):
22
+ if type == 'random':
23
+ choice = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA]
24
+ interpolation = choice[random.randint(0, len(choice)-1)]
25
+ elif type == 'nearest': interpolation = cv2.INTER_NEAREST
26
+ elif type == 'linear': interpolation = cv2.INTER_LINEAR
27
+ elif type == 'cubic': interpolation = cv2.INTER_CUBIC
28
+ elif type == 'area': interpolation = cv2.INTER_AREA
29
+ else: raise TypeError('Interpolation types only nearest, linear, cubic, area are supported!')
30
+ return interpolation
31
+
32
+ class CVRandomRotation(object):
33
+ def __init__(self, degrees=15):
34
+ assert isinstance(degrees, numbers.Number), "degree should be a single number."
35
+ assert degrees >= 0, "degree must be positive."
36
+ self.degrees = degrees
37
+
38
+ @staticmethod
39
+ def get_params(degrees):
40
+ return sample_sym(degrees)
41
+
42
+ def __call__(self, img):
43
+ angle = self.get_params(self.degrees)
44
+ src_h, src_w = img.shape[:2]
45
+ M = cv2.getRotationMatrix2D(center=(src_w/2, src_h/2), angle=angle, scale=1.0)
46
+ abs_cos, abs_sin = abs(M[0,0]), abs(M[0,1])
47
+ dst_w = int(src_h * abs_sin + src_w * abs_cos)
48
+ dst_h = int(src_h * abs_cos + src_w * abs_sin)
49
+ M[0, 2] += (dst_w - src_w)/2
50
+ M[1, 2] += (dst_h - src_h)/2
51
+
52
+ flags = get_interpolation()
53
+ return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)
54
+
55
+ class CVRandomAffine(object):
56
+ def __init__(self, degrees, translate=None, scale=None, shear=None):
57
+ assert isinstance(degrees, numbers.Number), "degree should be a single number."
58
+ assert degrees >= 0, "degree must be positive."
59
+ self.degrees = degrees
60
+
61
+ if translate is not None:
62
+ assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
63
+ "translate should be a list or tuple and it must be of length 2."
64
+ for t in translate:
65
+ if not (0.0 <= t <= 1.0):
66
+ raise ValueError("translation values should be between 0 and 1")
67
+ self.translate = translate
68
+
69
+ if scale is not None:
70
+ assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
71
+ "scale should be a list or tuple and it must be of length 2."
72
+ for s in scale:
73
+ if s <= 0:
74
+ raise ValueError("scale values should be positive")
75
+ self.scale = scale
76
+
77
+ if shear is not None:
78
+ if isinstance(shear, numbers.Number):
79
+ if shear < 0:
80
+ raise ValueError("If shear is a single number, it must be positive.")
81
+ self.shear = [shear]
82
+ else:
83
+ assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
84
+ "shear should be a list or tuple and it must be of length 2."
85
+ self.shear = shear
86
+ else:
87
+ self.shear = shear
88
+
89
+ def _get_inverse_affine_matrix(self, center, angle, translate, scale, shear):
90
+ # https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717
91
+ from numpy import sin, cos, tan
92
+
93
+ if isinstance(shear, numbers.Number):
94
+ shear = [shear, 0]
95
+
96
+ if not (isinstance(shear, (tuple, list)) and len(shear) == 2):
97
+ raise ValueError(
98
+ "Shear should be a single value or a tuple/list containing " +
99
+ "two values. Got {}".format(shear))
100
+
101
+ rot = math.radians(angle)
102
+ sx, sy = [math.radians(s) for s in shear]
103
+
104
+ cx, cy = center
105
+ tx, ty = translate
106
+
107
+ # RSS without scaling
108
+ a = cos(rot - sy) / cos(sy)
109
+ b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
110
+ c = sin(rot - sy) / cos(sy)
111
+ d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
112
+
113
+ # Inverted rotation matrix with scale and shear
114
+ # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
115
+ M = [d, -b, 0,
116
+ -c, a, 0]
117
+ M = [x / scale for x in M]
118
+
119
+ # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
120
+ M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
121
+ M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
122
+
123
+ # Apply center translation: C * RSS^-1 * C^-1 * T^-1
124
+ M[2] += cx
125
+ M[5] += cy
126
+ return M
127
+
128
+ @staticmethod
129
+ def get_params(degrees, translate, scale_ranges, shears, height):
130
+ angle = sample_sym(degrees)
131
+ if translate is not None:
132
+ max_dx = translate[0] * height
133
+ max_dy = translate[1] * height
134
+ translations = (np.round(sample_sym(max_dx)), np.round(sample_sym(max_dy)))
135
+ else:
136
+ translations = (0, 0)
137
+
138
+ if scale_ranges is not None:
139
+ scale = sample_uniform(scale_ranges[0], scale_ranges[1])
140
+ else:
141
+ scale = 1.0
142
+
143
+ if shears is not None:
144
+ if len(shears) == 1:
145
+ shear = [sample_sym(shears[0]), 0.]
146
+ elif len(shears) == 2:
147
+ shear = [sample_sym(shears[0]), sample_sym(shears[1])]
148
+ else:
149
+ shear = 0.0
150
+
151
+ return angle, translations, scale, shear
152
+
153
+
154
+ def __call__(self, img):
155
+ src_h, src_w = img.shape[:2]
156
+ angle, translate, scale, shear = self.get_params(
157
+ self.degrees, self.translate, self.scale, self.shear, src_h)
158
+
159
+ M = self._get_inverse_affine_matrix((src_w/2, src_h/2), angle, (0, 0), scale, shear)
160
+ M = np.array(M).reshape(2,3)
161
+
162
+ startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1), (0, src_h - 1)]
163
+ project = lambda x, y, a, b, c: int(a*x + b*y + c)
164
+ endpoints = [(project(x, y, *M[0]), project(x, y, *M[1])) for x, y in startpoints]
165
+
166
+ rect = cv2.minAreaRect(np.array(endpoints))
167
+ bbox = cv2.boxPoints(rect).astype(dtype=int)
168
+ max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
169
+ min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
170
+
171
+ dst_w = int(max_x - min_x)
172
+ dst_h = int(max_y - min_y)
173
+ M[0, 2] += (dst_w - src_w) / 2
174
+ M[1, 2] += (dst_h - src_h) / 2
175
+
176
+ # add translate
177
+ dst_w += int(abs(translate[0]))
178
+ dst_h += int(abs(translate[1]))
179
+ if translate[0] < 0: M[0, 2] += abs(translate[0])
180
+ if translate[1] < 0: M[1, 2] += abs(translate[1])
181
+
182
+ flags = get_interpolation()
183
+ return cv2.warpAffine(img, M, (dst_w , dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)
184
+
185
+ class CVRandomPerspective(object):
186
+ def __init__(self, distortion=0.5):
187
+ self.distortion = distortion
188
+
189
+ def get_params(self, width, height, distortion):
190
+ offset_h = sample_asym(distortion * height / 2, size=4).astype(dtype=int)
191
+ offset_w = sample_asym(distortion * width / 2, size=4).astype(dtype=int)
192
+ topleft = ( offset_w[0], offset_h[0])
193
+ topright = (width - 1 - offset_w[1], offset_h[1])
194
+ botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
195
+ botleft = ( offset_w[3], height - 1 - offset_h[3])
196
+
197
+ startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
198
+ endpoints = [topleft, topright, botright, botleft]
199
+ return np.array(startpoints, dtype=np.float32), np.array(endpoints, dtype=np.float32)
200
+
201
+ def __call__(self, img):
202
+ height, width = img.shape[:2]
203
+ startpoints, endpoints = self.get_params(width, height, self.distortion)
204
+ M = cv2.getPerspectiveTransform(startpoints, endpoints)
205
+
206
+ # TODO: more robust way to crop image
207
+ rect = cv2.minAreaRect(endpoints)
208
+ bbox = cv2.boxPoints(rect).astype(dtype=int)
209
+ max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
210
+ min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
211
+ min_x, min_y = max(min_x, 0), max(min_y, 0)
212
+
213
+ flags = get_interpolation()
214
+ img = cv2.warpPerspective(img, M, (max_x, max_y), flags=flags, borderMode=cv2.BORDER_REPLICATE)
215
+ img = img[min_y:, min_x:]
216
+ return img
217
+
218
+ class CVRescale(object):
219
+
220
+ def __init__(self, factor=4, base_size=(128, 512)):
221
+ """ Rescale the image through a Gaussian pyramid, then resize it back to its original size.
222
+
223
+ Args:
224
+ factor: number of Gaussian-pyramid downsampling steps; sampled from [0, factor] if a number, or from [factor[0], factor[1]] if a pair.
225
+ base_size: (height, width) of the pyramid's bottom layer.
226
+ """
227
+ if isinstance(factor, numbers.Number):
228
+ self.factor = round(sample_uniform(0, factor))
229
+ elif isinstance(factor, (tuple, list)) and len(factor) == 2:
230
+ self.factor = round(sample_uniform(factor[0], factor[1]))
231
+ else:
232
+ raise Exception('factor must be a number or a list/tuple of length 2')
233
+ # assert factor is valid
234
+ self.base_h, self.base_w = base_size[:2]
235
+
236
+ def __call__(self, img):
237
+ if self.factor == 0: return img
238
+ src_h, src_w = img.shape[:2]
239
+ cur_w, cur_h = self.base_w, self.base_h
240
+ scale_img = cv2.resize(img, (cur_w, cur_h), interpolation=get_interpolation())
241
+ for _ in range(self.factor):
242
+ scale_img = cv2.pyrDown(scale_img)
243
+ scale_img = cv2.resize(scale_img, (src_w, src_h), interpolation=get_interpolation())
244
+ return scale_img
245
+
246
+ class CVGaussianNoise(object):
247
+ def __init__(self, mean=0, var=20):
248
+ self.mean = mean
249
+ if isinstance(var, numbers.Number):
250
+ self.var = max(int(sample_asym(var)), 1)
251
+ elif isinstance(var, (tuple, list)) and len(var) == 2:
252
+ self.var = int(sample_uniform(var[0], var[1]))
253
+ else:
254
+ raise Exception('var must be a number or a list/tuple of length 2')
255
+
256
+ def __call__(self, img):
257
+ noise = np.random.normal(self.mean, self.var**0.5, img.shape)
258
+ img = np.clip(img + noise, 0, 255).astype(np.uint8)
259
+ return img
260
+
261
+ class CVMotionBlur(object):
262
+ def __init__(self, degrees=12, angle=90):
263
+ if isinstance(degrees, numbers.Number):
264
+ self.degree = max(int(sample_asym(degrees)), 1)
265
+ elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
266
+ self.degree = int(sample_uniform(degrees[0], degrees[1]))
267
+ else:
268
+ raise Exception('degrees must be a number or a list/tuple of length 2')
269
+ self.angle = sample_uniform(-angle, angle)
270
+
271
+ def __call__(self, img):
272
+ M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2), self.angle, 1)
273
+ motion_blur_kernel = np.zeros((self.degree, self.degree))
274
+ motion_blur_kernel[self.degree // 2, :] = 1
275
+ motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (self.degree, self.degree))
276
+ motion_blur_kernel = motion_blur_kernel / self.degree
277
+ img = cv2.filter2D(img, -1, motion_blur_kernel)
278
+ img = np.clip(img, 0, 255).astype(np.uint8)
279
+ return img
280
+
281
+ class CVGeometry(object):
282
+ def __init__(self, degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.),
283
+ shear=(45, 15), distortion=0.5, p=0.5):
284
+ self.p = p
285
+ type_p = random.random()
286
+ if type_p < 0.33:
287
+ self.transforms = CVRandomRotation(degrees=degrees)
288
+ elif type_p < 0.66:
289
+ self.transforms = CVRandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear)
290
+ else:
291
+ self.transforms = CVRandomPerspective(distortion=distortion)
292
+
293
+ def __call__(self, img):
294
+ if random.random() < self.p:
295
+ img = np.array(img)
296
+ return Image.fromarray(self.transforms(img))
297
+ else: return img
298
+
299
+ class CVDeterioration(object):
300
+ def __init__(self, var, degrees, factor, p=0.5):
301
+ self.p = p
302
+ transforms = []
303
+ if var is not None:
304
+ transforms.append(CVGaussianNoise(var=var))
305
+ if degrees is not None:
306
+ transforms.append(CVMotionBlur(degrees=degrees))
307
+ if factor is not None:
308
+ transforms.append(CVRescale(factor=factor))
309
+
310
+ random.shuffle(transforms)
311
+ transforms = Compose(transforms)
312
+ self.transforms = transforms
313
+
314
+ def __call__(self, img):
315
+ if random.random() < self.p:
316
+ img = np.array(img)
317
+ return Image.fromarray(self.transforms(img))
318
+ else: return img
319
+
320
+
321
+ class CVColorJitter(object):
322
+ def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5):
323
+ self.p = p
324
+ self.transforms = transforms.ColorJitter(brightness=brightness, contrast=contrast,
325
+ saturation=saturation, hue=hue)
326
+
327
+ def __call__(self, img):
328
+ if random.random() < self.p: return self.transforms(img)
329
+ else: return img
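For orientation, the sketch below shows one way these augmentation classes can be chained at training time. It assumes the imports and the `sample_sym`/`sample_asym`/`sample_uniform` helpers defined earlier in this file, plus a local word crop `demo.png` (a hypothetical file); the parameter values are illustrative and are not taken from the repository's training configuration.

```
from PIL import Image
from torchvision import transforms

# Illustrative pipeline: geometry -> deterioration -> colour jitter.
# Each wrapper applies its transform with probability p and returns a PIL image.
augment = transforms.Compose([
    CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.0),
               shear=(45, 15), distortion=0.5, p=0.5),
    CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
    CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25),
])

img = Image.open('demo.png').convert('RGB')   # hypothetical input word crop
aug = augment(img)                            # PIL.Image in, PIL.Image out
aug.save('demo_aug.png')
```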
utils.py ADDED
@@ -0,0 +1,304 @@
1
+ import logging
2
+ import os
3
+ import time
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import yaml
9
+ from matplotlib import colors
10
+ from matplotlib import pyplot as plt
11
+ from torch import Tensor, nn
12
+ from torch.utils.data import ConcatDataset
13
+
14
+ class CharsetMapper(object):
15
+ """A simple class to map ids into strings.
16
+
17
+ It works only when the character set is 1:1 mapping between individual
18
+ characters and individual ids.
19
+ """
20
+
21
+ def __init__(self,
22
+ filename='',
23
+ max_length=30,
24
+ null_char=u'\u2591'):
25
+ """Creates a lookup table.
26
+
27
+ Args:
28
+ filename: Path to charset file which maps characters to ids.
29
+ max_length: The max length of the id sequence and the string.
30
+ null_char: A unicode character used to replace the '<null>' character.
31
+ The default value is a light shade block '░'.
32
+ """
33
+ self.null_char = null_char
34
+ self.max_length = max_length
35
+
36
+ self.label_to_char = self._read_charset(filename)
37
+ self.char_to_label = dict(map(reversed, self.label_to_char.items()))
38
+ self.num_classes = len(self.label_to_char)
39
+
40
+ def _read_charset(self, filename):
41
+ """Reads a charset definition from a tab separated text file.
42
+
43
+ Args:
44
+ filename: a path to the charset file.
45
+
46
+ Returns:
47
+ a dictionary with keys equal to character codes and values - unicode
48
+ characters.
49
+ """
50
+ import re
51
+ pattern = re.compile(r'(\d+)\t(.+)')
52
+ charset = {}
53
+ self.null_label = 0
54
+ charset[self.null_label] = self.null_char
55
+ with open(filename, 'r') as f:
56
+ for i, line in enumerate(f):
57
+ m = pattern.match(line)
58
+ assert m, f'Incorrect charset file. line #{i}: {line}'
59
+ label = int(m.group(1)) + 1
60
+ char = m.group(2)
61
+ charset[label] = char
62
+ return charset
63
+
64
+ def trim(self, text):
65
+ assert isinstance(text, str)
66
+ return text.replace(self.null_char, '')
67
+
68
+ def get_text(self, labels, length=None, padding=True, trim=False):
69
+ """ Returns a string corresponding to a sequence of character ids.
70
+ """
71
+ length = length if length else self.max_length
72
+ labels = [l.item() if isinstance(l, Tensor) else int(l) for l in labels]
73
+ if padding:
74
+ labels = labels + [self.null_label] * (length-len(labels))
75
+ text = ''.join([self.label_to_char[label] for label in labels])
76
+ if trim: text = self.trim(text)
77
+ return text
78
+
79
+ def get_labels(self, text, length=None, padding=True, case_sensitive=False):
80
+ """ Returns the labels of the corresponding text.
81
+ """
82
+ length = length if length else self.max_length
83
+ if padding:
84
+ text = text + self.null_char * (length - len(text))
85
+ if not case_sensitive:
86
+ text = text.lower()
87
+ labels = [self.char_to_label[char] for char in text]
88
+ return labels
89
+
90
+ def pad_labels(self, labels, length=None):
91
+ length = length if length else self.max_length
92
+
93
+ return labels + [self.null_label] * (length - len(labels))
94
+
95
+ @property
96
+ def digits(self):
97
+ return '0123456789'
98
+
99
+ @property
100
+ def digit_labels(self):
101
+ return self.get_labels(self.digits, padding=False)
102
+
103
+ @property
104
+ def alphabets(self):
105
+ all_chars = list(self.char_to_label.keys())
106
+ valid_chars = []
107
+ for c in all_chars:
108
+ if c in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ':
109
+ valid_chars.append(c)
110
+ return ''.join(valid_chars)
111
+
112
+ @property
113
+ def alphabet_labels(self):
114
+ return self.get_labels(self.alphabets, padding=False)
115
+
116
+
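As a quick illustration of the mapping above, the sketch below round-trips a word through `get_labels` and `get_text`. The charset path is hypothetical; any file in the tab-separated `id<TAB>char` format read by `_read_charset` would work.

```
# Hypothetical charset file path.
charset = CharsetMapper(filename='data/charset_36.txt', max_length=25)

labels = charset.get_labels('abinet')        # character ids, padded with null_label
text = charset.get_text(labels, trim=True)   # trim strips the null padding again
assert text == 'abinet'
print(charset.num_classes, charset.null_label)
```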
117
+ class Timer(object):
118
+ """A simple timer."""
119
+ def __init__(self):
120
+ self.data_time = 0.
121
+ self.data_diff = 0.
122
+ self.data_total_time = 0.
123
+ self.data_call = 0
124
+ self.running_time = 0.
125
+ self.running_diff = 0.
126
+ self.running_total_time = 0.
127
+ self.running_call = 0
128
+
129
+ def tic(self):
130
+ self.start_time = time.time()
131
+ self.running_time = self.start_time
132
+
133
+ def toc_data(self):
134
+ self.data_time = time.time()
135
+ self.data_diff = self.data_time - self.running_time
136
+ self.data_total_time += self.data_diff
137
+ self.data_call += 1
138
+
139
+ def toc_running(self):
140
+ self.running_time = time.time()
141
+ self.running_diff = self.running_time - self.data_time
142
+ self.running_total_time += self.running_diff
143
+ self.running_call += 1
144
+
145
+ def total_time(self):
146
+ return self.data_total_time + self.running_total_time
147
+
148
+ def average_time(self):
149
+ return self.average_data_time() + self.average_running_time()
150
+
151
+ def average_data_time(self):
152
+ return self.data_total_time / (self.data_call or 1)
153
+
154
+ def average_running_time(self):
155
+ return self.running_total_time / (self.running_call or 1)
156
+
157
+
158
+ class Logger(object):
159
+ _handle = None
160
+ _root = None
161
+
162
+ @staticmethod
163
+ def init(output_dir, name, phase):
164
+ format = '[%(asctime)s %(filename)s:%(lineno)d %(levelname)s {}] ' \
165
+ '%(message)s'.format(name)
166
+ logging.basicConfig(level=logging.INFO, format=format)
167
+
168
+ os.makedirs(output_dir, exist_ok=True)
170
+ config_path = os.path.join(output_dir, f'{phase}.txt')
171
+ Logger._handle = logging.FileHandler(config_path)
172
+ Logger._root = logging.getLogger()
173
+
174
+ @staticmethod
175
+ def enable_file():
176
+ if Logger._handle is None or Logger._root is None:
177
+ raise Exception('Invoke Logger.init() first!')
178
+ Logger._root.addHandler(Logger._handle)
179
+
180
+ @staticmethod
181
+ def disable_file():
182
+ if Logger._handle is None or Logger._root is None:
183
+ raise Exception('Invoke Logger.init() first!')
184
+ Logger._root.removeHandler(Logger._handle)
185
+
186
+
187
+ class Config(object):
188
+
189
+ def __init__(self, config_path, host=True):
190
+ def __dict2attr(d, prefix=''):
191
+ for k, v in d.items():
192
+ if isinstance(v, dict):
193
+ __dict2attr(v, f'{prefix}{k}_')
194
+ else:
195
+ if k == 'phase':
196
+ assert v in ['train', 'test']
197
+ if k == 'stage':
198
+ assert v in ['pretrain-vision', 'pretrain-language',
199
+ 'train-semi-super', 'train-super']
200
+ self.__setattr__(f'{prefix}{k}', v)
201
+
202
+ assert os.path.exists(config_path), '%s does not exist!' % config_path
203
+ with open(config_path) as file:
204
+ config_dict = yaml.load(file, Loader=yaml.FullLoader)
205
+ with open('configs/template.yaml') as file:
206
+ default_config_dict = yaml.load(file, Loader=yaml.FullLoader)
207
+ __dict2attr(default_config_dict)
208
+ __dict2attr(config_dict)
209
+ self.global_workdir = os.path.join(self.global_workdir, self.global_name)
210
+
211
+ def __getattr__(self, item):
212
+ attr = self.__dict__.get(item)
213
+ if attr is None:
214
+ attr = dict()
215
+ prefix = f'{item}_'
216
+ for k, v in self.__dict__.items():
217
+ if k.startswith(prefix):
218
+ n = k.replace(prefix, '')
219
+ attr[n] = v
220
+ return attr if len(attr) > 0 else None
221
+ else:
222
+ return attr
223
+
224
+ def __repr__(self):
225
+ s = 'ModelConfig(\n'
226
+ for i, (k, v) in enumerate(sorted(vars(self).items())):
227
+ s += f'\t({i}): {k} = {v}\n'
228
+ s += ')'
229
+ return s
230
+
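The nested YAML is flattened into underscore-joined attributes by `__dict2attr`, and `__getattr__` regroups attributes sharing a prefix back into a dict. A small sketch of that behaviour, assuming a hypothetical config file that is merged over `configs/template.yaml` exactly as `__init__` does:

```
# Hypothetical configs/example.yaml:
#   global:
#     name: demo
#     workdir: workdir/
#   dataset:
#     train:
#       batch_size: 384
config = Config('configs/example.yaml')

print(config.global_name)                # 'demo'  (nested keys joined with '_')
print(config.dataset_train_batch_size)   # 384
print(config.dataset_train)              # dict regrouped by __getattr__, e.g. including 'batch_size': 384
```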
231
+ def blend_mask(image, mask, alpha=0.5, cmap='jet', color='b', color_alpha=1.0):
232
+ # normalize mask
233
+ mask = (mask-mask.min()) / (mask.max() - mask.min() + np.finfo(float).eps)
234
+ if mask.shape != image.shape:
235
+ mask = cv2.resize(mask,(image.shape[1], image.shape[0]))
236
+ # get color map
237
+ color_map = plt.get_cmap(cmap)
238
+ mask = color_map(mask)[:,:,:3]
239
+ # convert float to uint8
240
+ mask = (mask * 255).astype(dtype=np.uint8)
241
+
242
+ # set the basic color
243
+ basic_color = np.array(colors.to_rgb(color)) * 255
244
+ basic_color = np.tile(basic_color, [image.shape[0], image.shape[1], 1])
245
+ basic_color = basic_color.astype(dtype=np.uint8)
246
+ # blend with basic color
247
+ blended_img = cv2.addWeighted(image, color_alpha, basic_color, 1-color_alpha, 0)
248
+ # blend with mask
249
+ blended_img = cv2.addWeighted(blended_img, alpha, mask, 1-alpha, 0)
250
+
251
+ return blended_img
252
+
253
+ def onehot(label, depth, device=None):
254
+ """
255
+ Args:
256
+ label: shape (n1, n2, ..., )
257
+ depth: a scalar
258
+
259
+ Returns:
260
+ onehot: (n1, n2, ..., depth)
261
+ """
262
+ if not isinstance(label, torch.Tensor):
263
+ label = torch.tensor(label, device=device)
264
+ onehot = torch.zeros(label.size() + torch.Size([depth]), device=device)
265
+ onehot = onehot.scatter_(-1, label.unsqueeze(-1), 1)
266
+
267
+ return onehot
268
+
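A minimal, illustrative check of `onehot`'s shape contract:

```
import torch

labels = torch.tensor([[1, 0, 2],
                       [2, 2, 0]])        # shape (2, 3)
oh = onehot(labels, depth=4)              # shape (2, 3, 4)
assert oh.shape == (2, 3, 4)
assert torch.equal(oh.argmax(dim=-1), labels)
```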
269
+ class MyDataParallel(nn.DataParallel):
270
+
271
+ def gather(self, outputs, target_device):
272
+ r"""
273
+ Gathers tensors from different GPUs on a specified device
274
+ (-1 means the CPU).
275
+ """
276
+ def gather_map(outputs):
277
+ out = outputs[0]
278
+ if isinstance(out, (str, int, float)):
279
+ return out
280
+ if isinstance(out, list) and isinstance(out[0], str):
281
+ return [o for out in outputs for o in out]
282
+ if isinstance(out, torch.Tensor):
283
+ return torch.nn.parallel._functions.Gather.apply(target_device, self.dim, *outputs)
284
+ if out is None:
285
+ return None
286
+ if isinstance(out, dict):
287
+ if not all((len(out) == len(d) for d in outputs)):
288
+ raise ValueError('All dicts must have the same number of keys')
289
+ return type(out)(((k, gather_map([d[k] for d in outputs]))
290
+ for k in out))
291
+ return type(out)(map(gather_map, zip(*outputs)))
292
+
293
+ # Recursive function calls like this create reference cycles.
294
+ # Setting the function to None clears the refcycle.
295
+ try:
296
+ res = gather_map(outputs)
297
+ finally:
298
+ gather_map = None
299
+ return res
300
+
301
+
302
+ class MyConcatDataset(ConcatDataset):
303
+ def __getattr__(self, k):
304
+ return getattr(self.datasets[0], k)
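`MyConcatDataset` forwards unknown attribute lookups to its first dataset, so a concatenation of datasets can be queried as if it were a single one. A sketch with a dummy dataset (the `_DummyDataset` class and its `charset_path` attribute are illustrative stand-ins, not part of this repository):

```
from torch.utils.data import Dataset

class _DummyDataset(Dataset):
    charset_path = 'data/charset_36.txt'   # hypothetical attribute
    def __len__(self): return 4
    def __getitem__(self, i): return i

merged = MyConcatDataset([_DummyDataset(), _DummyDataset()])
print(len(merged))           # 8, handled by ConcatDataset
print(merged.charset_path)   # forwarded to datasets[0] by __getattr__
```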