MMOCR / mmocr /models /common /layers /transformer_layers.py
tomofi's picture
Add application file
2366e36
raw
history blame
No virus
6.59 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner import BaseModule
from mmocr.models.common.modules import (MultiHeadAttention,
PositionwiseFeedForward)
class TFEncoderLayer(BaseModule):
    """Transformer Encoder Layer.

    One self-attention sub-layer followed by one position-wise feedforward
    sub-layer, each wrapped in a residual connection. Layer normalization
    is applied either before each sub-layer (pre-norm) or after each
    residual addition (post-norm), depending on ``operation_order``.

    Args:
        d_model (int): The number of expected features
            in the encoder inputs (default=512).
        d_inner (int): The dimension of the feedforward
            network model (default=256).
        n_head (int): The number of heads in the
            multiheadattention models (default=8).
        d_k (int): Feature size of the key, forwarded to
            ``MultiHeadAttention`` (default=64).
        d_v (int): Feature size of the value, forwarded to
            ``MultiHeadAttention`` (default=64).
        dropout (float): Dropout rate forwarded to both the attention
            and the feedforward modules (default=0.1).
        qkv_bias (bool): Add bias in projection layer. Default: False.
        act_cfg (dict): Activation cfg for feedforward module.
        operation_order (tuple[str] | list[str] | None): The execution
            order of operation in transformer. Either
            ('self_attn', 'norm', 'ffn', 'norm') (post-norm) or
            ('norm', 'self_attn', 'norm', 'ffn') (pre-norm). A list with
            the same items is accepted and stored as a tuple.
            Default: None, which selects the pre-norm order.
    """

    # The only two supported layouts.
    _PRE_NORM = ('norm', 'self_attn', 'norm', 'ffn')
    _POST_NORM = ('self_attn', 'norm', 'ffn', 'norm')

    def __init__(self,
                 d_model=512,
                 d_inner=256,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False,
                 act_cfg=dict(type='mmcv.GELU'),
                 operation_order=None):
        super().__init__()
        self.attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.mlp = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout, act_cfg=act_cfg)
        self.norm2 = nn.LayerNorm(d_model)

        if operation_order is None:
            operation_order = self._PRE_NORM
        elif isinstance(operation_order, list):
            # A list never compares equal to a tuple, so an otherwise valid
            # order given as a list (e.g. from a config deserialized from
            # JSON) would be rejected below without this normalization.
            operation_order = tuple(operation_order)
        assert operation_order in (self._PRE_NORM, self._POST_NORM), (
            f'Unsupported operation_order {operation_order!r}; expected '
            f'{self._PRE_NORM} or {self._POST_NORM}.')
        self.operation_order = operation_order

    def forward(self, x, mask=None):
        """Run self-attention and feedforward over ``x``.

        Args:
            x (Tensor): Input features. The last dimension must be
                ``d_model`` because it is fed to ``nn.LayerNorm(d_model)``;
                presumably shaped (N, T, d_model) -- TODO confirm against
                ``MultiHeadAttention``.
            mask (Tensor | None): Attention mask handed unchanged to the
                self-attention module. Default: None.

        Returns:
            Tensor: The encoded features (presumably the same shape
            as ``x``).
        """
        if self.operation_order == ('self_attn', 'norm', 'ffn', 'norm'):
            # Post-norm: normalize after each residual addition.
            residual = x
            x = residual + self.attn(x, x, x, mask)
            x = self.norm1(x)

            residual = x
            x = residual + self.mlp(x)
            x = self.norm2(x)
        elif self.operation_order == ('norm', 'self_attn', 'norm', 'ffn'):
            # Pre-norm: normalize the sub-layer input, while the residual
            # branch carries the un-normalized tensor.
            residual = x
            x = self.norm1(x)
            x = residual + self.attn(x, x, x, mask)

            residual = x
            x = self.norm2(x)
            x = residual + self.mlp(x)

        return x
class TFDecoderLayer(nn.Module):
    """Transformer Decoder Layer.

    Self-attention over the decoder input, attention over the encoder
    output, then a position-wise feedforward network, each wrapped in a
    residual connection. Layer normalization is applied either before each
    sub-layer (pre-norm) or after each residual addition (post-norm),
    depending on ``operation_order``.

    Args:
        d_model (int): The number of expected features
            in the decoder inputs (default=512).
        d_inner (int): The dimension of the feedforward
            network model (default=256).
        n_head (int): The number of heads in the
            multiheadattention models (default=8).
        d_k (int): Feature size of the key, forwarded to
            ``MultiHeadAttention`` (default=64).
        d_v (int): Feature size of the value, forwarded to
            ``MultiHeadAttention`` (default=64).
        dropout (float): Dropout rate forwarded to both attention modules
            and the feedforward module (default=0.1).
        qkv_bias (bool): Add bias in projection layer. Default: False.
        act_cfg (dict): Activation cfg for feedforward module.
        operation_order (tuple[str] | list[str] | None): The execution
            order of operation in transformer. Either
            ('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', 'norm')
            (post-norm) or ('norm', 'self_attn', 'norm', 'enc_dec_attn',
            'norm', 'ffn') (pre-norm). A list with the same items is
            accepted and stored as a tuple.
            Default: None, which selects the pre-norm order.
    """

    # The only two supported layouts.
    _PRE_NORM = ('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn')
    _POST_NORM = ('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', 'norm')

    def __init__(self,
                 d_model=512,
                 d_inner=256,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False,
                 act_cfg=dict(type='mmcv.GELU'),
                 operation_order=None):
        super().__init__()

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.self_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)
        self.enc_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)
        self.mlp = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout, act_cfg=act_cfg)

        if operation_order is None:
            operation_order = self._PRE_NORM
        elif isinstance(operation_order, list):
            # A list never compares equal to a tuple, so an otherwise valid
            # order given as a list (e.g. from a config deserialized from
            # JSON) would be rejected below without this normalization.
            operation_order = tuple(operation_order)
        assert operation_order in (self._PRE_NORM, self._POST_NORM), (
            f'Unsupported operation_order {operation_order!r}; expected '
            f'{self._PRE_NORM} or {self._POST_NORM}.')
        self.operation_order = operation_order

    def forward(self,
                dec_input,
                enc_output,
                self_attn_mask=None,
                dec_enc_attn_mask=None):
        """Decode one layer.

        Args:
            dec_input (Tensor): Decoder-side features. The last dimension
                must be ``d_model`` because it is fed to
                ``nn.LayerNorm(d_model)``; presumably shaped
                (N, T, d_model) -- TODO confirm.
            enc_output (Tensor): Encoder features used as key and value of
                the encoder-decoder attention.
            self_attn_mask (Tensor | None): Mask handed unchanged to the
                self-attention module. Default: None.
            dec_enc_attn_mask (Tensor | None): Mask handed unchanged to
                the encoder-decoder attention module. Default: None.

        Returns:
            Tensor: The decoded features (presumably the same shape as
            ``dec_input``).
        """
        # Residual sums are computed out-of-place on purpose: an in-place
        # ``+=`` would write into the tensor returned by a sub-module,
        # which may alias that sub-module's input (and therefore the
        # caller's ``dec_input``) or a buffer autograd still needs.
        if self.operation_order == ('self_attn', 'norm', 'enc_dec_attn',
                                    'norm', 'ffn', 'norm'):
            # Post-norm: normalize after each residual addition.
            dec_attn_out = self.self_attn(dec_input, dec_input, dec_input,
                                          self_attn_mask)
            dec_attn_out = self.norm1(dec_attn_out + dec_input)

            enc_dec_attn_out = self.enc_attn(dec_attn_out, enc_output,
                                             enc_output, dec_enc_attn_mask)
            enc_dec_attn_out = self.norm2(enc_dec_attn_out + dec_attn_out)

            mlp_out = self.mlp(enc_dec_attn_out)
            mlp_out = self.norm3(mlp_out + enc_dec_attn_out)
        elif self.operation_order == ('norm', 'self_attn', 'norm',
                                      'enc_dec_attn', 'norm', 'ffn'):
            # Pre-norm: normalize the sub-layer input, while the residual
            # branch carries the un-normalized tensor.
            dec_input_norm = self.norm1(dec_input)
            dec_attn_out = self.self_attn(dec_input_norm, dec_input_norm,
                                          dec_input_norm, self_attn_mask)
            dec_attn_out = dec_attn_out + dec_input

            enc_dec_attn_in = self.norm2(dec_attn_out)
            enc_dec_attn_out = self.enc_attn(enc_dec_attn_in, enc_output,
                                             enc_output, dec_enc_attn_mask)
            enc_dec_attn_out = enc_dec_attn_out + dec_attn_out

            mlp_out = self.mlp(self.norm3(enc_dec_attn_out))
            mlp_out = mlp_out + enc_dec_attn_out

        return mlp_out