Your Name committed on
Commit
0bb6548
1 Parent(s): 74a6211
app.py CHANGED
@@ -1,122 +1,35 @@
1
- import argparse
2
- import logging
3
  import os
4
- import glob
5
- import tqdm
6
- import torch, re
7
- import PIL
8
- import cv2
9
- import numpy as np
10
- import torch.nn.functional as F
11
- from torchvision import transforms
12
- from utils import Config, Logger, CharsetMapper
13
-
14
- def get_model(config):
15
- import importlib
16
- names = config.model_name.split('.')
17
- module_name, class_name = '.'.join(names[:-1]), names[-1]
18
- cls = getattr(importlib.import_module(module_name), class_name)
19
- model = cls(config)
20
- logging.info(model)
21
- model = model.eval()
22
- return model
23
-
24
- def preprocess(img, width, height):
25
- img = cv2.resize(np.array(img), (width, height))
26
- img = transforms.ToTensor()(img).unsqueeze(0)
27
- mean = torch.tensor([0.485, 0.456, 0.406])
28
- std = torch.tensor([0.229, 0.224, 0.225])
29
- return (img-mean[...,None,None]) / std[...,None,None]
30
-
31
- def postprocess(output, charset, model_eval):
32
- def _get_output(last_output, model_eval):
33
- if isinstance(last_output, (tuple, list)):
34
- for res in last_output:
35
- if res['name'] == model_eval: output = res
36
- else: output = last_output
37
- return output
38
-
39
- def _decode(logit):
40
- """ Greed decode """
41
- out = F.softmax(logit, dim=2)
42
- pt_text, pt_scores, pt_lengths = [], [], []
43
- for o in out:
44
- text = charset.get_text(o.argmax(dim=1), padding=False, trim=False)
45
- text = text.split(charset.null_char)[0] # end at end-token
46
- pt_text.append(text)
47
- pt_scores.append(o.max(dim=1)[0])
48
- pt_lengths.append(min(len(text) + 1, charset.max_length)) # one for end-token
49
- return pt_text, pt_scores, pt_lengths
50
-
51
- output = _get_output(output, model_eval)
52
- logits, pt_lengths = output['logits'], output['pt_lengths']
53
- pt_text, pt_scores, pt_lengths_ = _decode(logits)
54
-
55
- return pt_text, pt_scores, pt_lengths_
56
-
57
- def load(model, file, device=None, strict=True):
58
- if device is None: device = 'cpu'
59
- elif isinstance(device, int): device = torch.device('cuda', device)
60
- assert os.path.isfile(file)
61
- state = torch.load(file, map_location=device)
62
- if set(state.keys()) == {'model', 'opt'}:
63
- state = state['model']
64
- model.load_state_dict(state, strict=strict)
65
- return model
66
 
67
 
68
- def main():
69
- parser = argparse.ArgumentParser()
70
- parser.add_argument('--config', type=str, default='configs/train_abinet.yaml',
71
- help='path to config file')
72
- parser.add_argument('--input', type=str, default='figs/test')
73
- parser.add_argument('--cuda', type=int, default=-1)
74
- parser.add_argument('--checkpoint', type=str, default='workdir/train-abinet/best-train-abinet.pth')
75
- parser.add_argument('--model_eval', type=str, default='alignment',
76
- choices=['alignment', 'vision', 'language'])
77
- args = parser.parse_args()
78
- config = Config(args.config)
79
- if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
80
- if args.model_eval is not None: config.model_eval = args.model_eval
81
- config.global_phase = 'test'
82
- config.model_vision_checkpoint, config.model_language_checkpoint = None, None
83
- device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'
84
-
85
- Logger.init(config.global_workdir, config.global_name, config.global_phase)
86
- Logger.enable_file()
87
- logging.info(config)
88
-
89
- logging.info('Construct model.')
90
- model = get_model(config).to(device)
91
- model = load(model, config.model_checkpoint, device=device)
92
- charset = CharsetMapper(filename=config.dataset_charset_path,
93
- max_length=config.dataset_max_length + 1)
94
-
95
- if os.path.isdir(args.input):
96
- paths = [os.path.join(args.input, fname) for fname in os.listdir(args.input)]
97
- else:
98
- paths = glob.glob(os.path.expanduser(args.input))
99
- assert paths, "The input path(s) was not found"
100
- paths = sorted(paths)
101
-
102
-
103
- count = 0
104
- checks = 0
105
- print(tqdm.tqdm(paths))
106
- for path in tqdm.tqdm(paths):
107
- img = PIL.Image.open(path).convert('RGB')
108
- img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
109
- img = img.to(device)
110
- res = model(img)
111
- pt_text, _, __ = postprocess(res, charset, config.model_eval)
112
- a = re.findall(r'(\d{6}).png', path)[0]
113
- # print(a)
114
- # print(pt_text[0], "Lol")
115
- # a = re.findall(r'base/(.*).pn', path)[0]
116
- checks += 1
117
- if a.lower() != pt_text[0].lower():
118
- count += 1
119
- print(f'label:{a.lower()} ||| guess:{pt_text[0]} ||| count_fails:{str(count)}/{str(checks)}')
120
 
121
- if __name__ == '__main__':
122
- main()
1
  import os
2
+ os.system('pip install --upgrade gdown')
3
+ import gdown
4
+ gdown.download(id='1z0O-bBy1z6WVV1QBBbFz8biXGl7ni--r', output='workdir.zip')
5
+ os.system('unzip workdir.zip')
6
 
7
 
8
+ import glob
9
+ import gradio as gr
10
+ from demo import get_model, preprocess, postprocess, load
11
+ from utils import Config, Logger, CharsetMapper
12
 
13
+ config = Config('configs/train_abinet.yaml')
14
+ config.model_vision_checkpoint = None
15
+ model = get_model(config)
16
+ model = load(model, 'workdir/train-abinet/best-train-abinet.pth')
17
+ charset = CharsetMapper(filename=config.dataset_charset_path, max_length=config.dataset_max_length + 1)
18
+
19
+ def process_image(image):
20
+ img = image.convert('RGB')
21
+ img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
22
+ res = model(img)
23
+ return postprocess(res, charset, 'alignment')[0][0]
24
+
25
+ title = "Made with ABINet"
26
+ description = "I hate captchas"
27
+
28
+ iface = gr.Interface(fn=process_image,
29
+ inputs=gr.inputs.Image(type="pil"),
30
+ outputs=gr.outputs.Textbox(),
31
+ title=title,
32
+ description=description,
33
+ examples=glob.glob('figs_captchas/*.jpg'))
34
+
35
+ iface.launch(debug=True)
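Note that the examples glob above looks for 'figs_captchas/*.jpg', while the files added under figs_captchas/ in this commit are .png, so the examples gallery may come up empty unless the pattern is adjusted. For a quick local check without launching the Gradio UI, the same inference path can be exercised directly; a minimal sketch, assuming the model and process_image defined above and one of the .png examples added in this commit:

import PIL.Image

# Sketch only: reuse the already-constructed model via process_image() on a local file.
img = PIL.Image.open('figs_captchas/show (3).png')
print(process_image(img))   # prints the recognized text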
configs/template.yaml ADDED
@@ -0,0 +1,67 @@
1
+ global:
2
+ name: exp
3
+ phase: train
4
+ stage: pretrain-vision
5
+ workdir: /tmp/workdir
6
+ seed: ~
7
+
8
+ dataset:
9
+ train: {
10
+ roots: ['data/training/MJ/MJ_train/',
11
+ 'data/training/MJ/MJ_test/',
12
+ 'data/training/MJ/MJ_valid/',
13
+ 'data/training/ST'],
14
+ batch_size: 128
15
+ }
16
+ test: {
17
+ roots: ['data/evaluation/IIIT5k_3000',
18
+ 'data/evaluation/SVT',
19
+ 'data/evaluation/SVTP',
20
+ 'data/evaluation/IC13_857',
21
+ 'data/evaluation/IC15_1811',
22
+ 'data/evaluation/CUTE80'],
23
+ batch_size: 128
24
+ }
25
+ charset_path: data/charset_36.txt
26
+ num_workers: 4
27
+ max_length: 25 # 30
28
+ image_height: 32
29
+ image_width: 128
30
+ case_sensitive: False
31
+ eval_case_sensitive: False
32
+ data_aug: True
33
+ multiscales: False
34
+ pin_memory: True
35
+ smooth_label: False
36
+ smooth_factor: 0.1
37
+ one_hot_y: True
38
+ use_sm: False
39
+
40
+ training:
41
+ epochs: 6
42
+ show_iters: 50
43
+ eval_iters: 3000
44
+ save_iters: 20000
45
+ start_iters: 0
46
+ stats_iters: 100000
47
+
48
+ optimizer:
49
+ type: Adadelta # Adadelta, Adam
50
+ true_wd: False
51
+ wd: 0. # 0.001
52
+ bn_wd: False
53
+ args: {
54
+ # betas: !!python/tuple [0.9, 0.99], # betas=(0.9,0.99) for AdamW
55
+ # betas: !!python/tuple [0.9, 0.999], # for default Adam
56
+ }
57
+ clip_grad: 20
58
+ lr: [1.0, 1.0, 1.0] # lr: [0.005, 0.005, 0.005]
59
+ scheduler: {
60
+ periods: [3, 2, 1],
61
+ gamma: 0.1,
62
+ }
63
+
64
+ model:
65
+ name: 'modules.model_abinet.ABINetModel'
66
+ checkpoint: ~
67
+ strict: True
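The modules in this commit read nested YAML keys as flattened attributes (e.g. config.dataset_max_length, config.model_vision_checkpoint, config.global_phase), which suggests the Config class in utils joins section and key names with underscores. A rough sketch of that flattening, shown only to illustrate how this template maps to the attribute names used in the code; the real Config implementation lives in utils and may differ:

import yaml

def flatten(cfg, prefix=''):
    # 'dataset.max_length' -> 'dataset_max_length', 'model.name' -> 'model_name', ...
    flat = {}
    for key, value in cfg.items():
        name = f'{prefix}{key}'
        if isinstance(value, dict):
            flat.update(flatten(value, f'{name}_'))
        else:
            flat[name] = value
    return flat

with open('configs/template.yaml') as f:
    flat = flatten(yaml.safe_load(f))
print(flat['dataset_max_length'], flat['model_name'])   # 25 modules.model_abinet.ABINetModel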
configs/train_abinet.yaml ADDED
@@ -0,0 +1,71 @@
1
+ global:
2
+ name: train-abinet
3
+ phase: train
4
+ stage: train-super
5
+ workdir: workdir
6
+ seed: ~
7
+
8
+ dataset:
9
+ train: {
10
+ roots: [
11
+ 'output_tbell_dataset/',
12
+ # 'data/training/MJ/MJ_train/',
13
+ # 'data/training/MJ/MJ_test/',
14
+ # 'data/training/MJ/MJ_valid/',
15
+ # 'data/training/ST'
16
+ ],
17
+ batch_size: 50
18
+ }
19
+ test: {
20
+ roots: [
21
+ 'output_tbell_dataset/'
22
+ ],
23
+ batch_size: 50
24
+ }
25
+ data_aug: True
26
+ multiscales: False
27
+ num_workers: 5
28
+
29
+ training:
30
+ epochs: 50
31
+ show_iters: 200
32
+ eval_iters: 300
33
+ # save_iters: 3000
34
+
35
+ optimizer:
36
+ type: Adamax
37
+ true_wd: False
38
+ wd: 0.0
39
+ bn_wd: False
40
+ clip_grad: 20
41
+ lr: 0.0001
42
+ args: {
43
+ betas: !!python/tuple [0.9, 0.999], # for default Adam
44
+ }
45
+ scheduler: {
46
+ periods: [6, 4],
47
+ gamma: 0.1,
48
+ }
49
+
50
+ model:
51
+ name: 'modules.model_abinet_iter.ABINetIterModel'
52
+ iter_size: 5
53
+ ensemble: ''
54
+ use_vision: True
55
+ vision: {
56
+ checkpoint: workdir/pretrain-vision-model/best-pretrain-vision-model.pth,
57
+ loss_weight: 1.,
58
+ attention: 'position',
59
+ backbone: 'transformer',
60
+ backbone_ln: 3,
61
+ }
62
+ # language: {
63
+ # checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth,
64
+ # num_layers: 4,
65
+ # loss_weight: 1.,
66
+ # detach: True,
67
+ # use_self_attn: False
68
+ # }
69
+ alignment: {
70
+ loss_weight: 1.,
71
+ }
data/charset_36.txt ADDED
@@ -0,0 +1,36 @@
1
+ 0 a
2
+ 1 b
3
+ 2 c
4
+ 3 d
5
+ 4 e
6
+ 5 f
7
+ 6 g
8
+ 7 h
9
+ 8 i
10
+ 9 j
11
+ 10 k
12
+ 11 l
13
+ 12 m
14
+ 13 n
15
+ 14 o
16
+ 15 p
17
+ 16 q
18
+ 17 r
19
+ 18 s
20
+ 19 t
21
+ 20 u
22
+ 21 v
23
+ 22 w
24
+ 23 x
25
+ 24 y
26
+ 25 z
27
+ 26 1
28
+ 27 2
29
+ 28 3
30
+ 29 4
31
+ 30 5
32
+ 31 6
33
+ 32 7
34
+ 33 8
35
+ 34 9
36
+ 35 0
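data/charset_36.txt maps integer labels to the 36 lowercase alphanumeric characters, one 'label character' pair per line; CharsetMapper in utils builds its lookup tables from this file (and presumably manages the extra null/end token referenced as charset.null_label elsewhere in this commit). A minimal parsing sketch of the file format only, not the actual CharsetMapper implementation:

# Format sketch only; the real parser is CharsetMapper in utils.py.
label_to_char = {}
with open('data/charset_36.txt') as f:
    for line in f:
        label, char = line.split()
        label_to_char[int(label)] = char
print(len(label_to_char), label_to_char[0], label_to_char[35])   # 36 a 0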
data/charset_62.txt ADDED
@@ -0,0 +1,62 @@
1
+ 0 0
2
+ 1 1
3
+ 2 2
4
+ 3 3
5
+ 4 4
6
+ 5 5
7
+ 6 6
8
+ 7 7
9
+ 8 8
10
+ 9 9
11
+ 10 A
12
+ 11 B
13
+ 12 C
14
+ 13 D
15
+ 14 E
16
+ 15 F
17
+ 16 G
18
+ 17 H
19
+ 18 I
20
+ 19 J
21
+ 20 K
22
+ 21 L
23
+ 22 M
24
+ 23 N
25
+ 24 O
26
+ 25 P
27
+ 26 Q
28
+ 27 R
29
+ 28 S
30
+ 29 T
31
+ 30 U
32
+ 31 V
33
+ 32 W
34
+ 33 X
35
+ 34 Y
36
+ 35 Z
37
+ 36 a
38
+ 37 b
39
+ 38 c
40
+ 39 d
41
+ 40 e
42
+ 41 f
43
+ 42 g
44
+ 43 h
45
+ 44 i
46
+ 45 j
47
+ 46 k
48
+ 47 l
49
+ 48 m
50
+ 49 n
51
+ 50 o
52
+ 51 p
53
+ 52 q
54
+ 53 r
55
+ 54 s
56
+ 55 t
57
+ 56 u
58
+ 57 v
59
+ 58 w
60
+ 59 x
61
+ 60 y
62
+ 61 z
figs_captchas/show (3).png ADDED
figs_captchas/show (4).png ADDED
figs_captchas/show (5).png ADDED
figs_captchas/show (6).png ADDED
figs_captchas/show (7).png ADDED
figs_captchas/show (8).png ADDED
modules/attention.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .transformer import PositionalEncoding
4
+
5
+ class Attention(nn.Module):
6
+ def __init__(self, in_channels=512, max_length=25, n_feature=256):
7
+ super().__init__()
8
+ self.max_length = max_length
9
+
10
+ self.f0_embedding = nn.Embedding(max_length, in_channels)
11
+ self.w0 = nn.Linear(max_length, n_feature)
12
+ self.wv = nn.Linear(in_channels, in_channels)
13
+ self.we = nn.Linear(in_channels, max_length)
14
+
15
+ self.active = nn.Tanh()
16
+ self.softmax = nn.Softmax(dim=2)
17
+
18
+ def forward(self, enc_output):
19
+ enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2)
20
+ reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device)
21
+ reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S)
22
+ reading_order_embed = self.f0_embedding(reading_order) # b,25,512
23
+
24
+ t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256
25
+ t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512
26
+
27
+ attn = self.we(t) # b,256,25
28
+ attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256
29
+ g_output = torch.bmm(attn, enc_output) # b,25,512
30
+ return g_output, attn.view(*attn.shape[:2], 8, 32)
31
+
32
+
33
+ def encoder_layer(in_c, out_c, k=3, s=2, p=1):
34
+ return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p),
35
+ nn.BatchNorm2d(out_c),
36
+ nn.ReLU(True))
37
+
38
+ def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None):
39
+ align_corners = None if mode=='nearest' else True
40
+ return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor,
41
+ mode=mode, align_corners=align_corners),
42
+ nn.Conv2d(in_c, out_c, k, s, p),
43
+ nn.BatchNorm2d(out_c),
44
+ nn.ReLU(True))
45
+
46
+
47
+ class PositionAttention(nn.Module):
48
+ def __init__(self, max_length, in_channels=512, num_channels=64,
49
+ h=8, w=32, mode='nearest', **kwargs):
50
+ super().__init__()
51
+ self.max_length = max_length
52
+ self.k_encoder = nn.Sequential(
53
+ encoder_layer(in_channels, num_channels, s=(1, 2)),
54
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
55
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
56
+ encoder_layer(num_channels, num_channels, s=(2, 2))
57
+ )
58
+ self.k_decoder = nn.Sequential(
59
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
60
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
61
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
62
+ decoder_layer(num_channels, in_channels, size=(h, w), mode=mode)
63
+ )
64
+
65
+ self.pos_encoder = PositionalEncoding(in_channels, dropout=0, max_len=max_length)
66
+ self.project = nn.Linear(in_channels, in_channels)
67
+
68
+ def forward(self, x):
69
+ N, E, H, W = x.size()
70
+ k, v = x, x # (N, E, H, W)
71
+
72
+ # calculate key vector
73
+ features = []
74
+ for i in range(0, len(self.k_encoder)):
75
+ k = self.k_encoder[i](k)
76
+ features.append(k)
77
+ for i in range(0, len(self.k_decoder) - 1):
78
+ k = self.k_decoder[i](k)
79
+ k = k + features[len(self.k_decoder) - 2 - i]
80
+ k = self.k_decoder[-1](k)
81
+
82
+ # calculate query vector
83
+ # TODO q=f(q,k)
84
+ zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E)
85
+ q = self.pos_encoder(zeros) # (T, N, E)
86
+ q = q.permute(1, 0, 2) # (N, T, E)
87
+ q = self.project(q) # (N, T, E)
88
+
89
+ # calculate attention
90
+ attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W))
91
+ attn_scores = attn_scores / (E ** 0.5)
92
+ attn_scores = torch.softmax(attn_scores, dim=-1)
93
+
94
+ v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E)
95
+ attn_vecs = torch.bmm(attn_scores, v) # (N, T, E)
96
+
97
+ return attn_vecs, attn_scores.view(N, -1, H, W)
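PositionAttention builds its keys with a small U-Net-style encoder/decoder over the backbone features and derives its queries from positional encodings only, so its output always has max_length time steps regardless of the input text. A sanity-check sketch with the 8x32, 512-channel geometry used elsewhere in this commit (assumes the modules package is importable):

import torch
from modules.attention import PositionAttention

attn = PositionAttention(max_length=26)   # 25 chars + end token, matching the configs
feat = torch.randn(2, 512, 8, 32)         # (N, E, H, W) backbone output
vecs, scores = attn(feat)
print(vecs.shape, scores.shape)           # torch.Size([2, 26, 512]) torch.Size([2, 26, 8, 32])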
modules/backbone.py ADDED
@@ -0,0 +1,36 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from modules.model import _default_tfmer_cfg
6
+ from modules.resnet import resnet45
7
+ from modules.transformer import (PositionalEncoding,
8
+ TransformerEncoder,
9
+ TransformerEncoderLayer)
10
+
11
+
12
+ class ResTranformer(nn.Module):
13
+ def __init__(self, config):
14
+ super().__init__()
15
+ self.resnet = resnet45()
16
+
17
+ self.d_model = ifnone(config.model_vision_d_model, _default_tfmer_cfg['d_model'])
18
+ nhead = ifnone(config.model_vision_nhead, _default_tfmer_cfg['nhead'])
19
+ d_inner = ifnone(config.model_vision_d_inner, _default_tfmer_cfg['d_inner'])
20
+ dropout = ifnone(config.model_vision_dropout, _default_tfmer_cfg['dropout'])
21
+ activation = ifnone(config.model_vision_activation, _default_tfmer_cfg['activation'])
22
+ num_layers = ifnone(config.model_vision_backbone_ln, 2)
23
+
24
+ self.pos_encoder = PositionalEncoding(self.d_model, max_len=8*32)
25
+ encoder_layer = TransformerEncoderLayer(d_model=self.d_model, nhead=nhead,
26
+ dim_feedforward=d_inner, dropout=dropout, activation=activation)
27
+ self.transformer = TransformerEncoder(encoder_layer, num_layers)
28
+
29
+ def forward(self, images):
30
+ feature = self.resnet(images)
31
+ n, c, h, w = feature.shape
32
+ feature = feature.view(n, c, -1).permute(2, 0, 1)
33
+ feature = self.pos_encoder(feature)
34
+ feature = self.transformer(feature)
35
+ feature = feature.permute(1, 2, 0).view(n, c, h, w)
36
+ return feature
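ResTranformer treats each of the H*W spatial positions of the ResNet feature map as one transformer token: the map is flattened to a (H*W, N, C) sequence, position-encoded, encoded, and reshaped back. The reshaping round-trip isolated on a dummy tensor (illustration only):

import torch

n, c, h, w = 2, 512, 8, 32
feature = torch.randn(n, c, h, w)
seq = feature.view(n, c, -1).permute(2, 0, 1)      # (h*w, n, c): one token per spatial position
# ... the transformer encoder runs on `seq` here ...
restored = seq.permute(1, 2, 0).reshape(n, c, h, w)
assert torch.equal(restored, feature)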
modules/model.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from utils import CharsetMapper
5
+
6
+
7
+ _default_tfmer_cfg = dict(d_model=512, nhead=8, d_inner=2048, # 1024
8
+ dropout=0.1, activation='relu')
9
+
10
+ class Model(nn.Module):
11
+
12
+ def __init__(self, config):
13
+ super().__init__()
14
+ self.max_length = config.dataset_max_length + 1
15
+ self.charset = CharsetMapper(config.dataset_charset_path, max_length=self.max_length)
16
+
17
+ def load(self, source, device=None, strict=True):
18
+ state = torch.load(source, map_location=device)
19
+ self.load_state_dict(state['model'], strict=strict)
20
+
21
+ def _get_length(self, logit, dim=-1):
22
+ """ Greed decoder to obtain length from logit"""
23
+ out = (logit.argmax(dim=-1) == self.charset.null_label)
24
+ abn = out.any(dim)
25
+ out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
26
+ out = out + 1 # additional end token
27
+ out = torch.where(abn, out, out.new_tensor(logit.shape[1]))
28
+ return out
29
+
30
+ @staticmethod
31
+ def _get_padding_mask(length, max_length):
32
+ length = length.unsqueeze(-1)
33
+ grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
34
+ return grid >= length
35
+
36
+ @staticmethod
37
+ def _get_square_subsequent_mask(sz, device, diagonal=0, fw=True):
38
+ r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
39
+ Unmasked positions are filled with float(0.0).
40
+ """
41
+ mask = (torch.triu(torch.ones(sz, sz, device=device), diagonal=diagonal) == 1)
42
+ if fw: mask = mask.transpose(0, 1)
43
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
44
+ return mask
45
+
46
+ @staticmethod
47
+ def _get_location_mask(sz, device=None):
48
+ mask = torch.eye(sz, device=device)
49
+ mask = mask.float().masked_fill(mask == 1, float('-inf'))
50
+ return mask
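Model._get_length implements greedy length decoding: the predicted length is the index of the first predicted null/end token plus one, falling back to the full sequence length T when no end token appears. The same rule on a toy tensor (standalone illustration; null_label=0 is an assumption here, the real value comes from CharsetMapper):

import torch

null_label = 0                                   # assumed end-token label for this toy example
logits = torch.zeros(1, 5, 3)                    # (N, T, C)
logits[0, :, 1] = 1.0                            # predict class 1 at every step...
logits[0, 3, 0] = 2.0                            # ...except an end token at t=3

is_null = logits.argmax(dim=-1) == null_label    # (N, T)
has_null = is_null.any(dim=-1)
first_null = ((is_null.int().cumsum(-1) == 1) & is_null).max(-1)[1]
lengths = torch.where(has_null, first_null + 1, torch.full_like(first_null, logits.shape[1]))
print(lengths)                                   # tensor([4]) -> 3 characters + end token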
modules/model_abinet.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from .model_vision import BaseVision
6
+ from .model_language import BCNLanguage
7
+ from .model_alignment import BaseAlignment
8
+
9
+
10
+ class ABINetModel(nn.Module):
11
+ def __init__(self, config):
12
+ super().__init__()
13
+ self.use_alignment = ifnone(config.model_use_alignment, True)
14
+ self.max_length = config.dataset_max_length + 1 # additional stop token
15
+ self.vision = BaseVision(config)
16
+ self.language = BCNLanguage(config)
17
+ if self.use_alignment: self.alignment = BaseAlignment(config)
18
+
19
+ def forward(self, images, *args):
20
+ v_res = self.vision(images)
21
+ v_tokens = torch.softmax(v_res['logits'], dim=-1)
22
+ v_lengths = v_res['pt_lengths'].clamp_(2, self.max_length) # TODO: move to language model
23
+
24
+ l_res = self.language(v_tokens, v_lengths)
25
+ if not self.use_alignment:
26
+ return l_res, v_res
27
+ l_feature, v_feature = l_res['feature'], v_res['feature']
28
+
29
+ a_res = self.alignment(l_feature, v_feature)
30
+ return a_res, l_res, v_res
modules/model_abinet_iter.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from .model_vision import BaseVision
6
+ from .model_language import BCNLanguage
7
+ from .model_alignment import BaseAlignment
8
+
9
+
10
+ class ABINetIterModel(nn.Module):
11
+ def __init__(self, config):
12
+ super().__init__()
13
+ self.iter_size = ifnone(config.model_iter_size, 1)
14
+ self.max_length = config.dataset_max_length + 1 # additional stop token
15
+ self.vision = BaseVision(config)
16
+ self.language = BCNLanguage(config)
17
+ self.alignment = BaseAlignment(config)
18
+
19
+ def forward(self, images, *args):
20
+ v_res = self.vision(images)
21
+ a_res = v_res
22
+ all_l_res, all_a_res = [], []
23
+ for _ in range(self.iter_size):
24
+ tokens = torch.softmax(a_res['logits'], dim=-1)
25
+ lengths = a_res['pt_lengths']
26
+ lengths.clamp_(2, self.max_length) # TODO: move to language model
27
+ l_res = self.language(tokens, lengths)
28
+ all_l_res.append(l_res)
29
+ a_res = self.alignment(l_res['feature'], v_res['feature'])
30
+ all_a_res.append(a_res)
31
+ if self.training:
32
+ return all_a_res, all_l_res, v_res
33
+ else:
34
+ return a_res, all_l_res[-1], v_res
modules/model_alignment.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from modules.model import Model, _default_tfmer_cfg
6
+
7
+
8
+ class BaseAlignment(Model):
9
+ def __init__(self, config):
10
+ super().__init__(config)
11
+ d_model = ifnone(config.model_alignment_d_model, _default_tfmer_cfg['d_model'])
12
+
13
+ self.loss_weight = ifnone(config.model_alignment_loss_weight, 1.0)
14
+ self.max_length = config.dataset_max_length + 1 # additional stop token
15
+ self.w_att = nn.Linear(2 * d_model, d_model)
16
+ self.cls = nn.Linear(d_model, self.charset.num_classes)
17
+
18
+ def forward(self, l_feature, v_feature):
19
+ """
20
+ Args:
21
+ l_feature: (N, T, E) where T is length, N is batch size and d is dim of model
22
+ v_feature: (N, T, E) shape the same as l_feature
23
+ l_lengths: (N,)
24
+ v_lengths: (N,)
25
+ """
26
+ f = torch.cat((l_feature, v_feature), dim=2)
27
+ f_att = torch.sigmoid(self.w_att(f))
28
+ output = f_att * v_feature + (1 - f_att) * l_feature
29
+
30
+ logits = self.cls(output) # (N, T, C)
31
+ pt_lengths = self._get_length(logits)
32
+
33
+ return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight':self.loss_weight,
34
+ 'name': 'alignment'}
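BaseAlignment fuses the language and vision features with a learned sigmoid gate, f_att = sigmoid(W[l; v]), and returns f_att * v + (1 - f_att) * l, i.e. a per-element convex blend of the two streams before the final classifier. A shape-level sketch of the gating on dummy tensors (illustration only):

import torch
import torch.nn as nn

N, T, E = 2, 26, 512
l_feature, v_feature = torch.randn(N, T, E), torch.randn(N, T, E)
w_att = nn.Linear(2 * E, E)

gate = torch.sigmoid(w_att(torch.cat((l_feature, v_feature), dim=2)))   # (N, T, E), values in [0, 1]
fused = gate * v_feature + (1 - gate) * l_feature                       # (N, T, E)
print(fused.shape)                                                      # torch.Size([2, 26, 512])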
modules/model_language.py ADDED
@@ -0,0 +1,67 @@
1
+ import logging
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from modules.model import _default_tfmer_cfg
6
+ from modules.model import Model
7
+ from modules.transformer import (PositionalEncoding,
8
+ TransformerDecoder,
9
+ TransformerDecoderLayer)
10
+
11
+
12
+ class BCNLanguage(Model):
13
+ def __init__(self, config):
14
+ super().__init__(config)
15
+ d_model = ifnone(config.model_language_d_model, _default_tfmer_cfg['d_model'])
16
+ nhead = ifnone(config.model_language_nhead, _default_tfmer_cfg['nhead'])
17
+ d_inner = ifnone(config.model_language_d_inner, _default_tfmer_cfg['d_inner'])
18
+ dropout = ifnone(config.model_language_dropout, _default_tfmer_cfg['dropout'])
19
+ activation = ifnone(config.model_language_activation, _default_tfmer_cfg['activation'])
20
+ num_layers = ifnone(config.model_language_num_layers, 4)
21
+ self.d_model = d_model
22
+ self.detach = ifnone(config.model_language_detach, True)
23
+ self.use_self_attn = ifnone(config.model_language_use_self_attn, False)
24
+ self.loss_weight = ifnone(config.model_language_loss_weight, 1.0)
25
+ self.max_length = config.dataset_max_length + 1 # additional stop token
26
+ self.debug = ifnone(config.global_debug, False)
27
+
28
+ self.proj = nn.Linear(self.charset.num_classes, d_model, False)
29
+ self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length)
30
+ self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length)
31
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout,
32
+ activation, self_attn=self.use_self_attn, debug=self.debug)
33
+ self.model = TransformerDecoder(decoder_layer, num_layers)
34
+
35
+ self.cls = nn.Linear(d_model, self.charset.num_classes)
36
+
37
+ if config.model_language_checkpoint is not None:
38
+ logging.info(f'Read language model from {config.model_language_checkpoint}.')
39
+ self.load(config.model_language_checkpoint)
40
+
41
+ def forward(self, tokens, lengths):
42
+ """
43
+ Args:
44
+ tokens: (N, T, C) where T is length, N is batch size and C is classes number
45
+ lengths: (N,)
46
+ """
47
+ if self.detach: tokens = tokens.detach()
48
+ embed = self.proj(tokens) # (N, T, E)
49
+ embed = embed.permute(1, 0, 2) # (T, N, E)
50
+ embed = self.token_encoder(embed) # (T, N, E)
51
+ padding_mask = self._get_padding_mask(lengths, self.max_length)
52
+
53
+ zeros = embed.new_zeros(*embed.shape)
54
+ qeury = self.pos_encoder(zeros)
55
+ location_mask = self._get_location_mask(self.max_length, tokens.device)
56
+ output = self.model(qeury, embed,
57
+ tgt_key_padding_mask=padding_mask,
58
+ memory_mask=location_mask,
59
+ memory_key_padding_mask=padding_mask) # (T, N, E)
60
+ output = output.permute(1, 0, 2) # (N, T, E)
61
+
62
+ logits = self.cls(output) # (N, T, C)
63
+ pt_lengths = self._get_length(logits)
64
+
65
+ res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths,
66
+ 'loss_weight':self.loss_weight, 'name': 'language'}
67
+ return res
modules/model_vision.py ADDED
@@ -0,0 +1,47 @@
1
+ import logging
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from modules.attention import *
6
+ from modules.backbone import ResTranformer
7
+ from modules.model import Model
8
+ from modules.resnet import resnet45
9
+
10
+
11
+ class BaseVision(Model):
12
+ def __init__(self, config):
13
+ super().__init__(config)
14
+ self.loss_weight = ifnone(config.model_vision_loss_weight, 1.0)
15
+ self.out_channels = ifnone(config.model_vision_d_model, 512)
16
+
17
+ if config.model_vision_backbone == 'transformer':
18
+ self.backbone = ResTranformer(config)
19
+ else: self.backbone = resnet45()
20
+
21
+ if config.model_vision_attention == 'position':
22
+ mode = ifnone(config.model_vision_attention_mode, 'nearest')
23
+ self.attention = PositionAttention(
24
+ max_length=config.dataset_max_length + 1, # additional stop token
25
+ mode=mode,
26
+ )
27
+ elif config.model_vision_attention == 'attention':
28
+ self.attention = Attention(
29
+ max_length=config.dataset_max_length + 1, # additional stop token
30
+ n_feature=8*32,
31
+ )
32
+ else:
33
+ raise Exception(f'{config.model_vision_attention} is not valid.')
34
+ self.cls = nn.Linear(self.out_channels, self.charset.num_classes)
35
+
36
+ if config.model_vision_checkpoint is not None:
37
+ logging.info(f'Read vision model from {config.model_vision_checkpoint}.')
38
+ self.load(config.model_vision_checkpoint)
39
+
40
+ def forward(self, images, *args):
41
+ features = self.backbone(images) # (N, E, H, W)
42
+ attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W)
43
+ logits = self.cls(attn_vecs) # (N, T, C)
44
+ pt_lengths = self._get_length(logits)
45
+
46
+ return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
47
+ 'attn_scores': attn_scores, 'loss_weight':self.loss_weight, 'name': 'vision'}
modules/resnet.py ADDED
@@ -0,0 +1,104 @@
1
+ import math
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.model_zoo as model_zoo
6
+
7
+
8
+ def conv1x1(in_planes, out_planes, stride=1):
9
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
10
+
11
+
12
+ def conv3x3(in_planes, out_planes, stride=1):
13
+ "3x3 convolution with padding"
14
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
15
+ padding=1, bias=False)
16
+
17
+
18
+ class BasicBlock(nn.Module):
19
+ expansion = 1
20
+
21
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
22
+ super(BasicBlock, self).__init__()
23
+ self.conv1 = conv1x1(inplanes, planes)
24
+ self.bn1 = nn.BatchNorm2d(planes)
25
+ self.relu = nn.ReLU(inplace=True)
26
+ self.conv2 = conv3x3(planes, planes, stride)
27
+ self.bn2 = nn.BatchNorm2d(planes)
28
+ self.downsample = downsample
29
+ self.stride = stride
30
+
31
+ def forward(self, x):
32
+ residual = x
33
+
34
+ out = self.conv1(x)
35
+ out = self.bn1(out)
36
+ out = self.relu(out)
37
+
38
+ out = self.conv2(out)
39
+ out = self.bn2(out)
40
+
41
+ if self.downsample is not None:
42
+ residual = self.downsample(x)
43
+
44
+ out += residual
45
+ out = self.relu(out)
46
+
47
+ return out
48
+
49
+
50
+ class ResNet(nn.Module):
51
+
52
+ def __init__(self, block, layers):
53
+ self.inplanes = 32
54
+ super(ResNet, self).__init__()
55
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1,
56
+ bias=False)
57
+ self.bn1 = nn.BatchNorm2d(32)
58
+ self.relu = nn.ReLU(inplace=True)
59
+
60
+ self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
61
+ self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
62
+ self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
63
+ self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
64
+ self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
65
+
66
+ for m in self.modules():
67
+ if isinstance(m, nn.Conv2d):
68
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
69
+ m.weight.data.normal_(0, math.sqrt(3. / n))
70
+ elif isinstance(m, nn.BatchNorm2d):
71
+ m.weight.data.fill_(1)
72
+ m.bias.data.zero_()
73
+
74
+ def _make_layer(self, block, planes, blocks, stride=1):
75
+ downsample = None
76
+ if stride != 1 or self.inplanes != planes * block.expansion:
77
+ downsample = nn.Sequential(
78
+ nn.Conv2d(self.inplanes, planes * block.expansion,
79
+ kernel_size=1, stride=stride, bias=False),
80
+ nn.BatchNorm2d(planes * block.expansion),
81
+ )
82
+
83
+ layers = []
84
+ layers.append(block(self.inplanes, planes, stride, downsample))
85
+ self.inplanes = planes * block.expansion
86
+ for i in range(1, blocks):
87
+ layers.append(block(self.inplanes, planes))
88
+
89
+ return nn.Sequential(*layers)
90
+
91
+ def forward(self, x):
92
+ x = self.conv1(x)
93
+ x = self.bn1(x)
94
+ x = self.relu(x)
95
+ x = self.layer1(x)
96
+ x = self.layer2(x)
97
+ x = self.layer3(x)
98
+ x = self.layer4(x)
99
+ x = self.layer5(x)
100
+ return x
101
+
102
+
103
+ def resnet45():
104
+ return ResNet(BasicBlock, [3, 4, 6, 6, 3])
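resnet45 keeps stride 1 in the stem and downsamples only in layer1 and layer3, so the 32x128 inputs from the configs come out as 8x32 feature maps with 512 channels, matching the 8*32 positional encoding in modules/backbone.py and the h=8, w=32 defaults in PositionAttention. A quick shape check (sketch, assumes the modules package is importable):

import torch
from modules.resnet import resnet45

net = resnet45().eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 32, 128))   # (N, 3, H, W) as in the configs
print(out.shape)                            # torch.Size([1, 512, 8, 32])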
modules/transformer.py ADDED
@@ -0,0 +1,901 @@
1
+ # pytorch 1.5.0
2
+ import copy
3
+ import math
4
+ import warnings
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch import Tensor
10
+ from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList, Parameter
11
+ from torch.nn import functional as F
12
+ from torch.nn.init import constant_, xavier_uniform_
13
+
14
+
15
+ def multi_head_attention_forward(query, # type: Tensor
16
+ key, # type: Tensor
17
+ value, # type: Tensor
18
+ embed_dim_to_check, # type: int
19
+ num_heads, # type: int
20
+ in_proj_weight, # type: Tensor
21
+ in_proj_bias, # type: Tensor
22
+ bias_k, # type: Optional[Tensor]
23
+ bias_v, # type: Optional[Tensor]
24
+ add_zero_attn, # type: bool
25
+ dropout_p, # type: float
26
+ out_proj_weight, # type: Tensor
27
+ out_proj_bias, # type: Tensor
28
+ training=True, # type: bool
29
+ key_padding_mask=None, # type: Optional[Tensor]
30
+ need_weights=True, # type: bool
31
+ attn_mask=None, # type: Optional[Tensor]
32
+ use_separate_proj_weight=False, # type: bool
33
+ q_proj_weight=None, # type: Optional[Tensor]
34
+ k_proj_weight=None, # type: Optional[Tensor]
35
+ v_proj_weight=None, # type: Optional[Tensor]
36
+ static_k=None, # type: Optional[Tensor]
37
+ static_v=None # type: Optional[Tensor]
38
+ ):
39
+ # type: (...) -> Tuple[Tensor, Optional[Tensor]]
40
+ r"""
41
+ Args:
42
+ query, key, value: map a query and a set of key-value pairs to an output.
43
+ See "Attention Is All You Need" for more details.
44
+ embed_dim_to_check: total dimension of the model.
45
+ num_heads: parallel attention heads.
46
+ in_proj_weight, in_proj_bias: input projection weight and bias.
47
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
48
+ add_zero_attn: add a new batch of zeros to the key and
49
+ value sequences at dim=1.
50
+ dropout_p: probability of an element to be zeroed.
51
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
52
+ training: apply dropout if is ``True``.
53
+ key_padding_mask: if provided, specified padding elements in the key will
54
+ be ignored by the attention. This is an binary mask. When the value is True,
55
+ the corresponding value on the attention layer will be filled with -inf.
56
+ need_weights: output attn_output_weights.
57
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
58
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
59
+ use_separate_proj_weight: the function accept the proj. weights for query, key,
60
+ and value in different forms. If false, in_proj_weight will be used, which is
61
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
62
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
63
+ static_k, static_v: static key and value used for attention operators.
64
+ Shape:
65
+ Inputs:
66
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
67
+ the embedding dimension.
68
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
69
+ the embedding dimension.
70
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
71
+ the embedding dimension.
72
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
73
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
74
+ will be unchanged. If a BoolTensor is provided, the positions with the
75
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
76
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
77
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
78
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
79
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
80
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
81
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
82
+ is provided, it will be added to the attention weight.
83
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
84
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
85
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
86
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
87
+ Outputs:
88
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
89
+ E is the embedding dimension.
90
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
91
+ L is the target sequence length, S is the source sequence length.
92
+ """
93
+ # if not torch.jit.is_scripting():
94
+ # tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
95
+ # out_proj_weight, out_proj_bias)
96
+ # if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
97
+ # return handle_torch_function(
98
+ # multi_head_attention_forward, tens_ops, query, key, value,
99
+ # embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
100
+ # bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
101
+ # out_proj_bias, training=training, key_padding_mask=key_padding_mask,
102
+ # need_weights=need_weights, attn_mask=attn_mask,
103
+ # use_separate_proj_weight=use_separate_proj_weight,
104
+ # q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
105
+ # v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
106
+ tgt_len, bsz, embed_dim = query.size()
107
+ assert embed_dim == embed_dim_to_check
108
+ assert key.size() == value.size()
109
+
110
+ head_dim = embed_dim // num_heads
111
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
112
+ scaling = float(head_dim) ** -0.5
113
+
114
+ if not use_separate_proj_weight:
115
+ if torch.equal(query, key) and torch.equal(key, value):
116
+ # self-attention
117
+ q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
118
+
119
+ elif torch.equal(key, value):
120
+ # encoder-decoder attention
121
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
122
+ _b = in_proj_bias
123
+ _start = 0
124
+ _end = embed_dim
125
+ _w = in_proj_weight[_start:_end, :]
126
+ if _b is not None:
127
+ _b = _b[_start:_end]
128
+ q = F.linear(query, _w, _b)
129
+
130
+ if key is None:
131
+ assert value is None
132
+ k = None
133
+ v = None
134
+ else:
135
+
136
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
137
+ _b = in_proj_bias
138
+ _start = embed_dim
139
+ _end = None
140
+ _w = in_proj_weight[_start:, :]
141
+ if _b is not None:
142
+ _b = _b[_start:]
143
+ k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
144
+
145
+ else:
146
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
147
+ _b = in_proj_bias
148
+ _start = 0
149
+ _end = embed_dim
150
+ _w = in_proj_weight[_start:_end, :]
151
+ if _b is not None:
152
+ _b = _b[_start:_end]
153
+ q = F.linear(query, _w, _b)
154
+
155
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
156
+ _b = in_proj_bias
157
+ _start = embed_dim
158
+ _end = embed_dim * 2
159
+ _w = in_proj_weight[_start:_end, :]
160
+ if _b is not None:
161
+ _b = _b[_start:_end]
162
+ k = F.linear(key, _w, _b)
163
+
164
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
165
+ _b = in_proj_bias
166
+ _start = embed_dim * 2
167
+ _end = None
168
+ _w = in_proj_weight[_start:, :]
169
+ if _b is not None:
170
+ _b = _b[_start:]
171
+ v = F.linear(value, _w, _b)
172
+ else:
173
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
174
+ len1, len2 = q_proj_weight_non_opt.size()
175
+ assert len1 == embed_dim and len2 == query.size(-1)
176
+
177
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
178
+ len1, len2 = k_proj_weight_non_opt.size()
179
+ assert len1 == embed_dim and len2 == key.size(-1)
180
+
181
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
182
+ len1, len2 = v_proj_weight_non_opt.size()
183
+ assert len1 == embed_dim and len2 == value.size(-1)
184
+
185
+ if in_proj_bias is not None:
186
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
187
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
188
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
189
+ else:
190
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
191
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
192
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
193
+ q = q * scaling
194
+
195
+ if attn_mask is not None:
196
+ assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
197
+ attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
198
+ 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
199
+ if attn_mask.dtype == torch.uint8:
200
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
201
+ attn_mask = attn_mask.to(torch.bool)
202
+
203
+ if attn_mask.dim() == 2:
204
+ attn_mask = attn_mask.unsqueeze(0)
205
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
206
+ raise RuntimeError('The size of the 2D attn_mask is not correct.')
207
+ elif attn_mask.dim() == 3:
208
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
209
+ raise RuntimeError('The size of the 3D attn_mask is not correct.')
210
+ else:
211
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
212
+ # attn_mask's dim is 3 now.
213
+
214
+ # # convert ByteTensor key_padding_mask to bool
215
+ # if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
216
+ # warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
217
+ # key_padding_mask = key_padding_mask.to(torch.bool)
218
+
219
+ if bias_k is not None and bias_v is not None:
220
+ if static_k is None and static_v is None:
221
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
222
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
223
+ if attn_mask is not None:
224
+ attn_mask = pad(attn_mask, (0, 1))
225
+ if key_padding_mask is not None:
226
+ key_padding_mask = pad(key_padding_mask, (0, 1))
227
+ else:
228
+ assert static_k is None, "bias cannot be added to static key."
229
+ assert static_v is None, "bias cannot be added to static value."
230
+ else:
231
+ assert bias_k is None
232
+ assert bias_v is None
233
+
234
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
235
+ if k is not None:
236
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
237
+ if v is not None:
238
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
239
+
240
+ if static_k is not None:
241
+ assert static_k.size(0) == bsz * num_heads
242
+ assert static_k.size(2) == head_dim
243
+ k = static_k
244
+
245
+ if static_v is not None:
246
+ assert static_v.size(0) == bsz * num_heads
247
+ assert static_v.size(2) == head_dim
248
+ v = static_v
249
+
250
+ src_len = k.size(1)
251
+
252
+ if key_padding_mask is not None:
253
+ assert key_padding_mask.size(0) == bsz
254
+ assert key_padding_mask.size(1) == src_len
255
+
256
+ if add_zero_attn:
257
+ src_len += 1
258
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
259
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
260
+ if attn_mask is not None:
261
+ attn_mask = pad(attn_mask, (0, 1))
262
+ if key_padding_mask is not None:
263
+ key_padding_mask = pad(key_padding_mask, (0, 1))
264
+
265
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
266
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
267
+
268
+ if attn_mask is not None:
269
+ if attn_mask.dtype == torch.bool:
270
+ attn_output_weights.masked_fill_(attn_mask, float('-inf'))
271
+ else:
272
+ attn_output_weights += attn_mask
273
+
274
+
275
+ if key_padding_mask is not None:
276
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
277
+ attn_output_weights = attn_output_weights.masked_fill(
278
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
279
+ float('-inf'),
280
+ )
281
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
282
+
283
+ attn_output_weights = F.softmax(
284
+ attn_output_weights, dim=-1)
285
+ attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
286
+
287
+ attn_output = torch.bmm(attn_output_weights, v)
288
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
289
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
290
+ attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
291
+
292
+ if need_weights:
293
+ # average attention weights over heads
294
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
295
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
296
+ else:
297
+ return attn_output, None
298
+
299
+ class MultiheadAttention(Module):
300
+ r"""Allows the model to jointly attend to information
301
+ from different representation subspaces.
302
+ See reference: Attention Is All You Need
303
+ .. math::
304
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
305
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
306
+ Args:
307
+ embed_dim: total dimension of the model.
308
+ num_heads: parallel attention heads.
309
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
310
+ bias: add bias as module parameter. Default: True.
311
+ add_bias_kv: add bias to the key and value sequences at dim=0.
312
+ add_zero_attn: add a new batch of zeros to the key and
313
+ value sequences at dim=1.
314
+ kdim: total number of features in key. Default: None.
315
+ vdim: total number of features in value. Default: None.
316
+ Note: if kdim and vdim are None, they will be set to embed_dim such that
317
+ query, key, and value have the same number of features.
318
+ Examples::
319
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
320
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
321
+ """
322
+ # __annotations__ = {
323
+ # 'bias_k': torch._jit_internal.Optional[torch.Tensor],
324
+ # 'bias_v': torch._jit_internal.Optional[torch.Tensor],
325
+ # }
326
+ __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']
327
+
328
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
329
+ super(MultiheadAttention, self).__init__()
330
+ self.embed_dim = embed_dim
331
+ self.kdim = kdim if kdim is not None else embed_dim
332
+ self.vdim = vdim if vdim is not None else embed_dim
333
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
334
+
335
+ self.num_heads = num_heads
336
+ self.dropout = dropout
337
+ self.head_dim = embed_dim // num_heads
338
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
339
+
340
+ if self._qkv_same_embed_dim is False:
341
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
342
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
343
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
344
+ self.register_parameter('in_proj_weight', None)
345
+ else:
346
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
347
+ self.register_parameter('q_proj_weight', None)
348
+ self.register_parameter('k_proj_weight', None)
349
+ self.register_parameter('v_proj_weight', None)
350
+
351
+ if bias:
352
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
353
+ else:
354
+ self.register_parameter('in_proj_bias', None)
355
+ self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
356
+
357
+ if add_bias_kv:
358
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
359
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
360
+ else:
361
+ self.bias_k = self.bias_v = None
362
+
363
+ self.add_zero_attn = add_zero_attn
364
+
365
+ self._reset_parameters()
366
+
367
+ def _reset_parameters(self):
368
+ if self._qkv_same_embed_dim:
369
+ xavier_uniform_(self.in_proj_weight)
370
+ else:
371
+ xavier_uniform_(self.q_proj_weight)
372
+ xavier_uniform_(self.k_proj_weight)
373
+ xavier_uniform_(self.v_proj_weight)
374
+
375
+ if self.in_proj_bias is not None:
376
+ constant_(self.in_proj_bias, 0.)
377
+ constant_(self.out_proj.bias, 0.)
378
+ if self.bias_k is not None:
379
+ xavier_normal_(self.bias_k)
380
+ if self.bias_v is not None:
381
+ xavier_normal_(self.bias_v)
382
+
383
+ def __setstate__(self, state):
384
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
385
+ if '_qkv_same_embed_dim' not in state:
386
+ state['_qkv_same_embed_dim'] = True
387
+
388
+ super(MultiheadAttention, self).__setstate__(state)
389
+
390
+ def forward(self, query, key, value, key_padding_mask=None,
391
+ need_weights=True, attn_mask=None):
392
+ # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
393
+ r"""
394
+ Args:
395
+ query, key, value: map a query and a set of key-value pairs to an output.
396
+ See "Attention Is All You Need" for more details.
397
+ key_padding_mask: if provided, specified padding elements in the key will
398
+ be ignored by the attention. This is an binary mask. When the value is True,
399
+ the corresponding value on the attention layer will be filled with -inf.
400
+ need_weights: output attn_output_weights.
401
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
402
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
403
+ Shape:
404
+ - Inputs:
405
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
406
+ the embedding dimension.
407
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
408
+ the embedding dimension.
409
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
410
+ the embedding dimension.
411
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
412
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
413
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
414
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
415
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
416
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
417
+ S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
418
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
419
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
420
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
421
+ is provided, it will be added to the attention weight.
422
+ - Outputs:
423
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
424
+ E is the embedding dimension.
425
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
426
+ L is the target sequence length, S is the source sequence length.
427
+ """
428
+ if not self._qkv_same_embed_dim:
429
+ return multi_head_attention_forward(
430
+ query, key, value, self.embed_dim, self.num_heads,
431
+ self.in_proj_weight, self.in_proj_bias,
432
+ self.bias_k, self.bias_v, self.add_zero_attn,
433
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
434
+ training=self.training,
435
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
436
+ attn_mask=attn_mask, use_separate_proj_weight=True,
437
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
438
+ v_proj_weight=self.v_proj_weight)
439
+ else:
440
+ return multi_head_attention_forward(
441
+ query, key, value, self.embed_dim, self.num_heads,
442
+ self.in_proj_weight, self.in_proj_bias,
443
+ self.bias_k, self.bias_v, self.add_zero_attn,
444
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
445
+ training=self.training,
446
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
447
+ attn_mask=attn_mask)
448
+
449
+
450
+ class Transformer(Module):
451
+ r"""A transformer model. User is able to modify the attributes as needed. The architecture
452
+ is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
453
+ Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
454
+ Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
455
+ Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)
456
+ model with corresponding parameters.
457
+
458
+ Args:
459
+ d_model: the number of expected features in the encoder/decoder inputs (default=512).
460
+ nhead: the number of heads in the multiheadattention models (default=8).
461
+ num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
462
+ num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
463
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
464
+ dropout: the dropout value (default=0.1).
465
+ activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu).
466
+ custom_encoder: custom encoder (default=None).
467
+ custom_decoder: custom decoder (default=None).
468
+
469
+ Examples::
470
+ >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
471
+ >>> src = torch.rand((10, 32, 512))
472
+ >>> tgt = torch.rand((20, 32, 512))
473
+ >>> out = transformer_model(src, tgt)
474
+
475
+ Note: A full example to apply nn.Transformer module for the word language model is available in
476
+ https://github.com/pytorch/examples/tree/master/word_language_model
477
+ """
478
+
479
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
480
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
481
+ activation="relu", custom_encoder=None, custom_decoder=None):
482
+ super(Transformer, self).__init__()
483
+
484
+ if custom_encoder is not None:
485
+ self.encoder = custom_encoder
486
+ else:
487
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
488
+ encoder_norm = LayerNorm(d_model)
489
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
490
+
491
+ if custom_decoder is not None:
492
+ self.decoder = custom_decoder
493
+ else:
494
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
495
+ decoder_norm = LayerNorm(d_model)
496
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
497
+
498
+ self._reset_parameters()
499
+
500
+ self.d_model = d_model
501
+ self.nhead = nhead
502
+
503
+     def forward(self, src, tgt, src_mask=None, tgt_mask=None,
+                 memory_mask=None, src_key_padding_mask=None,
+                 tgt_key_padding_mask=None, memory_key_padding_mask=None):
+         # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor  # noqa
+         r"""Take in and process masked source/target sequences.
+
+         Args:
+             src: the sequence to the encoder (required).
+             tgt: the sequence to the decoder (required).
+             src_mask: the additive mask for the src sequence (optional).
+             tgt_mask: the additive mask for the tgt sequence (optional).
+             memory_mask: the additive mask for the encoder output (optional).
+             src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
+             tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
+             memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
+
+         Shape:
+             - src: :math:`(S, N, E)`.
+             - tgt: :math:`(T, N, E)`.
+             - src_mask: :math:`(S, S)`.
+             - tgt_mask: :math:`(T, T)`.
+             - memory_mask: :math:`(T, S)`.
+             - src_key_padding_mask: :math:`(N, S)`.
+             - tgt_key_padding_mask: :math:`(N, T)`.
+             - memory_key_padding_mask: :math:`(N, S)`.
+
+         Note: [src/tgt/memory]_mask ensures that position i is allowed to attend only the unmasked
+         positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+         while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+         are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+         is provided, it will be added to the attention weight.
+         [src/tgt/memory]_key_padding_mask marks the elements in the key that should be ignored by
+         the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
+         positions will be unchanged. If a BoolTensor is provided, the positions with the
+         value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
+
+         - output: :math:`(T, N, E)`.
+
+         Note: Due to the multi-head attention architecture in the transformer model,
+         the output sequence length of the transformer is the same as the input (i.e. target)
+         sequence length of the decoder.
+
+         where S is the source sequence length, T is the target sequence length, N is the
+         batch size and E is the feature number.
+
+         Examples:
+             >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
+         """
+
+         if src.size(1) != tgt.size(1):
+             raise RuntimeError("the batch number of src and tgt must be equal")
+
+         if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
+             raise RuntimeError("the feature number of src and tgt must be equal to d_model")
+
+         memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
+         output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
+                               tgt_key_padding_mask=tgt_key_padding_mask,
+                               memory_key_padding_mask=memory_key_padding_mask)
+         return output
+
+     def generate_square_subsequent_mask(self, sz):
+         r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
+         Unmasked positions are filled with float(0.0).
+         """
+         mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+         return mask
+
+     def _reset_parameters(self):
+         r"""Initialize parameters in the transformer model."""
+
+         for p in self.parameters():
+             if p.dim() > 1:
+                 xavier_uniform_(p)
+
+
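Editor's sketch (not part of the commit): a hypothetical end-to-end call showing how generate_square_subsequent_mask is typically combined with the Transformer class above to enforce causal decoding; the layer counts and shapes are illustrative.

import torch

# Hypothetical sketch using the Transformer class defined above.
model = Transformer(d_model=512, nhead=8, num_encoder_layers=2, num_decoder_layers=2)
src = torch.rand(10, 32, 512)   # (S, N, E)
tgt = torch.rand(20, 32, 512)   # (T, N, E)
tgt_mask = model.generate_square_subsequent_mask(tgt.size(0))  # (T, T): 0.0 on/below the diagonal, -inf above
out = model(src, tgt, tgt_mask=tgt_mask)  # (20, 32, 512)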
+ class TransformerEncoder(Module):
+     r"""TransformerEncoder is a stack of N encoder layers.
+
+     Args:
+         encoder_layer: an instance of the TransformerEncoderLayer() class (required).
+         num_layers: the number of sub-encoder-layers in the encoder (required).
+         norm: the layer normalization component (optional).
+
+     Examples::
+         >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
+         >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
+         >>> src = torch.rand(10, 32, 512)
+         >>> out = transformer_encoder(src)
+     """
+     __constants__ = ['norm']
+
+     def __init__(self, encoder_layer, num_layers, norm=None):
+         super(TransformerEncoder, self).__init__()
+         self.layers = _get_clones(encoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.norm = norm
+
+     def forward(self, src, mask=None, src_key_padding_mask=None):
+         # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
+         r"""Pass the input through the encoder layers in turn.
+
+         Args:
+             src: the sequence to the encoder (required).
+             mask: the mask for the src sequence (optional).
+             src_key_padding_mask: the mask for the src keys per batch (optional).
+
+         Shape:
+             see the docs in Transformer class.
+         """
+         output = src
+
+         for mod in self.layers:
+             output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
+
+         if self.norm is not None:
+             output = self.norm(output)
+
+         return output
+
+
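Editor's sketch (not part of the commit): a hypothetical example of masking padded positions with src_key_padding_mask, where True marks positions to ignore; the layer sizes and padding layout are illustrative.

import torch

# Hypothetical sketch using the encoder classes defined above.
encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
encoder = TransformerEncoder(encoder_layer, num_layers=2)
src = torch.rand(10, 32, 512)                      # (S, N, E)
pad_mask = torch.zeros(32, 10, dtype=torch.bool)   # (N, S)
pad_mask[:, 7:] = True                             # pretend the last 3 positions are padding
out = encoder(src, src_key_padding_mask=pad_mask)  # (10, 32, 512)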
+ class TransformerDecoder(Module):
+     r"""TransformerDecoder is a stack of N decoder layers.
+
+     Args:
+         decoder_layer: an instance of the TransformerDecoderLayer() class (required).
+         num_layers: the number of sub-decoder-layers in the decoder (required).
+         norm: the layer normalization component (optional).
+
+     Examples::
+         >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
+         >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
+         >>> memory = torch.rand(10, 32, 512)
+         >>> tgt = torch.rand(20, 32, 512)
+         >>> out = transformer_decoder(tgt, memory)
+     """
+     __constants__ = ['norm']
+
+     def __init__(self, decoder_layer, num_layers, norm=None):
+         super(TransformerDecoder, self).__init__()
+         self.layers = _get_clones(decoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.norm = norm
+
+     def forward(self, tgt, memory, memory2=None, tgt_mask=None,
+                 memory_mask=None, memory_mask2=None, tgt_key_padding_mask=None,
+                 memory_key_padding_mask=None, memory_key_padding_mask2=None):
+         # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
+         r"""Pass the inputs (and masks) through the decoder layers in turn.
+
+         Args:
+             tgt: the sequence to the decoder (required).
+             memory: the sequence from the last layer of the encoder (required).
+             memory2: an optional second memory sequence, used when the decoder layers are built with siamese=True (optional).
+             tgt_mask: the mask for the tgt sequence (optional).
+             memory_mask: the mask for the memory sequence (optional).
+             memory_mask2: the mask for the second memory sequence (optional).
+             tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+             memory_key_padding_mask: the mask for the memory keys per batch (optional).
+             memory_key_padding_mask2: the mask for the second memory keys per batch (optional).
+
+         Shape:
+             see the docs in Transformer class.
+         """
+         output = tgt
+
+         for mod in self.layers:
+             output = mod(output, memory, memory2=memory2, tgt_mask=tgt_mask,
+                          memory_mask=memory_mask, memory_mask2=memory_mask2,
+                          tgt_key_padding_mask=tgt_key_padding_mask,
+                          memory_key_padding_mask=memory_key_padding_mask,
+                          memory_key_padding_mask2=memory_key_padding_mask2)
+
+         if self.norm is not None:
+             output = self.norm(output)
+
+         return output
+
+ class TransformerEncoderLayer(Module):
+     r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
+     This standard encoder layer is based on the paper "Attention Is All You Need".
+     Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+     Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+     Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+     in a different way during application.
+
+     Args:
+         d_model: the number of expected features in the input (required).
+         nhead: the number of heads in the multiheadattention models (required).
+         dim_feedforward: the dimension of the feedforward network model (default=2048).
+         dropout: the dropout value (default=0.1).
+         activation: the activation function of intermediate layer, relu or gelu (default=relu).
+         debug: if True, keep the last self-attention map on the layer as ``self.attn`` (default=False).
+
+     Examples::
+         >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
+         >>> src = torch.rand(10, 32, 512)
+         >>> out = encoder_layer(src)
+     """
+
+     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+                  activation="relu", debug=False):
+         super(TransformerEncoderLayer, self).__init__()
+         self.debug = debug
+         self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
+         self.linear1 = Linear(d_model, dim_feedforward)
+         self.dropout = Dropout(dropout)
+         self.linear2 = Linear(dim_feedforward, d_model)
+
+         self.norm1 = LayerNorm(d_model)
+         self.norm2 = LayerNorm(d_model)
+         self.dropout1 = Dropout(dropout)
+         self.dropout2 = Dropout(dropout)
+
+         self.activation = _get_activation_fn(activation)
+
+     def __setstate__(self, state):
+         if 'activation' not in state:
+             state['activation'] = F.relu
+         super(TransformerEncoderLayer, self).__setstate__(state)
+
+     def forward(self, src, src_mask=None, src_key_padding_mask=None):
+         # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
+         r"""Pass the input through the encoder layer.
+
+         Args:
+             src: the sequence to the encoder layer (required).
+             src_mask: the mask for the src sequence (optional).
+             src_key_padding_mask: the mask for the src keys per batch (optional).
+
+         Shape:
+             see the docs in Transformer class.
+         """
+         src2, attn = self.self_attn(src, src, src, attn_mask=src_mask,
+                                     key_padding_mask=src_key_padding_mask)
+         if self.debug: self.attn = attn
+         src = src + self.dropout1(src2)
+         src = self.norm1(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+         src = src + self.dropout2(src2)
+         src = self.norm2(src)
+
+         return src
+
+
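Editor's sketch (not part of the commit): a hypothetical use of the debug flag added in this file, which stashes the most recent self-attention map on the layer; shapes are illustrative.

import torch

# Hypothetical sketch using the TransformerEncoderLayer defined above.
layer = TransformerEncoderLayer(d_model=512, nhead=8, debug=True)
src = torch.rand(10, 32, 512)   # (S, N, E)
out = layer(src)                # (10, 32, 512)
attn_map = layer.attn           # (N, S, S) = (32, 10, 10), averaged over heads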
+ class TransformerDecoderLayer(Module):
+     r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
+     This standard decoder layer is based on the paper "Attention Is All You Need".
+     Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+     Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+     Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+     in a different way during application.
+
+     Args:
+         d_model: the number of expected features in the input (required).
+         nhead: the number of heads in the multiheadattention models (required).
+         dim_feedforward: the dimension of the feedforward network model (default=2048).
+         dropout: the dropout value (default=0.1).
+         activation: the activation function of intermediate layer, relu or gelu (default=relu).
+         self_attn: whether to include the self-attention sub-layer (default=True).
+         siamese: whether to add a second cross-attention sub-layer over a second memory (default=False).
+         debug: if True, keep the last attention maps on the layer as ``self.attn``/``self.attn2``/``self.attn3`` (default=False).
+
+     Examples::
+         >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
+         >>> memory = torch.rand(10, 32, 512)
+         >>> tgt = torch.rand(20, 32, 512)
+         >>> out = decoder_layer(tgt, memory)
+     """
+
+     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+                  activation="relu", self_attn=True, siamese=False, debug=False):
+         super(TransformerDecoderLayer, self).__init__()
+         self.has_self_attn, self.siamese = self_attn, siamese
+         self.debug = debug
+         if self.has_self_attn:
+             self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
+             self.norm1 = LayerNorm(d_model)
+             self.dropout1 = Dropout(dropout)
+         self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
+         self.linear1 = Linear(d_model, dim_feedforward)
+         self.dropout = Dropout(dropout)
+         self.linear2 = Linear(dim_feedforward, d_model)
+
+         self.norm2 = LayerNorm(d_model)
+         self.norm3 = LayerNorm(d_model)
+         self.dropout2 = Dropout(dropout)
+         self.dropout3 = Dropout(dropout)
+         if self.siamese:
+             self.multihead_attn2 = MultiheadAttention(d_model, nhead, dropout=dropout)
+
+         self.activation = _get_activation_fn(activation)
+
+     def __setstate__(self, state):
+         if 'activation' not in state:
+             state['activation'] = F.relu
+         super(TransformerDecoderLayer, self).__setstate__(state)
+
+     def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
+                 tgt_key_padding_mask=None, memory_key_padding_mask=None,
+                 memory2=None, memory_mask2=None, memory_key_padding_mask2=None):
+         # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
+         r"""Pass the inputs (and masks) through the decoder layer.
+
+         Args:
+             tgt: the sequence to the decoder layer (required).
+             memory: the sequence from the last layer of the encoder (required).
+             tgt_mask: the mask for the tgt sequence (optional).
+             memory_mask: the mask for the memory sequence (optional).
+             tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+             memory_key_padding_mask: the mask for the memory keys per batch (optional).
+             memory2: the second memory sequence, attended to only when siamese=True (optional).
+             memory_mask2: the mask for the second memory sequence (optional).
+             memory_key_padding_mask2: the mask for the second memory keys per batch (optional).
+
+         Shape:
+             see the docs in Transformer class.
+         """
+         if self.has_self_attn:
+             tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
+                                         key_padding_mask=tgt_key_padding_mask)
+             tgt = tgt + self.dropout1(tgt2)
+             tgt = self.norm1(tgt)
+             if self.debug: self.attn = attn
+         tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
+                                           key_padding_mask=memory_key_padding_mask)
+         if self.debug: self.attn2 = attn2
+
+         if self.siamese:
+             tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2,
+                                                key_padding_mask=memory_key_padding_mask2)
+             tgt = tgt + self.dropout2(tgt3)
+             if self.debug: self.attn3 = attn3
+
+         tgt = tgt + self.dropout2(tgt2)
+         tgt = self.norm2(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+         tgt = tgt + self.dropout3(tgt2)
+         tgt = self.norm3(tgt)
+
+         return tgt
+
+
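Editor's sketch (not part of the commit): a hypothetical use of the siamese option, where each decoder layer cross-attends to a second memory stream passed through TransformerDecoder as memory2; the stream lengths and sizes are illustrative.

import torch

# Hypothetical sketch using the decoder classes defined above.
layer = TransformerDecoderLayer(d_model=512, nhead=8, siamese=True)
decoder = TransformerDecoder(layer, num_layers=2)
tgt = torch.rand(20, 32, 512)      # (T, N, E)
memory = torch.rand(10, 32, 512)   # first encoder stream, (S, N, E)
memory2 = torch.rand(12, 32, 512)  # second encoder stream, (S2, N, E)
out = decoder(tgt, memory, memory2=memory2)  # (20, 32, 512)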
+ def _get_clones(module, N):
+     return ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
+ def _get_activation_fn(activation):
+     if activation == "relu":
+         return F.relu
+     elif activation == "gelu":
+         return F.gelu
+
+     raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+
+
+ class PositionalEncoding(nn.Module):
+     r"""Inject some information about the relative or absolute position of the tokens
+     in the sequence. The positional encodings have the same dimension as
+     the embeddings, so that the two can be summed. Here, we use sine and cosine
+     functions of different frequencies.
+
+     .. math::
+         \text{PosEncoder}(pos, 2i) = \sin(pos/10000^{2i/d_{model}})
+
+         \text{PosEncoder}(pos, 2i+1) = \cos(pos/10000^{2i/d_{model}})
+
+     where pos is the word position and i is the embed idx.
+
+     Args:
+         d_model: the embed dim (required).
+         dropout: the dropout value (default=0.1).
+         max_len: the max. length of the incoming sequence (default=5000).
+
+     Examples:
+         >>> pos_encoder = PositionalEncoding(d_model)
+     """
+
+     def __init__(self, d_model, dropout=0.1, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout)
+
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         r"""Inputs of forward function
+
+         Args:
+             x: the sequence fed to the positional encoder model (required).
+
+         Shape:
+             x: [sequence length, batch size, embed dim]
+             output: [sequence length, batch size, embed dim]
+
+         Examples:
+             >>> output = pos_encoder(x)
+         """
+
+         x = x + self.pe[:x.size(0), :]
+         return self.dropout(x)
+
+
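Editor's sketch (not part of the commit): a hypothetical use of PositionalEncoding, adding sinusoidal position information to a (seq_len, batch, d_model) embedding tensor; sizes are illustrative.

import torch

# Hypothetical sketch using the PositionalEncoding defined above.
pos_encoder = PositionalEncoding(d_model=512, dropout=0.1, max_len=5000)
tokens = torch.rand(10, 32, 512)   # (seq_len, batch, d_model)
x = pos_encoder(tokens)            # same shape; positions added, then dropout applied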
+ if __name__ == '__main__':
+     transformer_model = Transformer(nhead=16, num_encoder_layers=12)
+     src = torch.rand((10, 32, 512))
+     tgt = torch.rand((20, 32, 512))
+     out = transformer_model(src, tgt)
+     print(out)
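# Editor's note (not part of the commit): if the demo above runs as written, the printed
# tensor should have shape (20, 32, 512), i.e. the decoder output keeps the target
# length (20), batch size (32), and d_model (512).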