Spaces:
Running
Running
File size: 6,601 Bytes
9bf4bd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmengine.model import BaseModule
from mmocr.models.common.modules import (MultiHeadAttention,
PositionwiseFeedForward)
class TFEncoderLayer(BaseModule):
    """Transformer Encoder Layer.

    The layer is made of a self-attention sub-layer followed by a
    position-wise feed-forward sub-layer. Each sub-layer is wrapped in a
    residual connection, and ``operation_order`` selects whether layer
    normalization is applied before the sub-layer (pre-norm) or after the
    residual addition (post-norm).

    Args:
        d_model (int): The number of expected features in the encoder
            inputs (default=512).
        d_inner (int): The dimension of the feedforward
            network model (default=256).
        n_head (int): The number of heads in the
            multiheadattention models (default=8).
        d_k (int): Total number of features in key.
        d_v (int): Total number of features in value.
        dropout (float): Dropout value forwarded to both the attention
            module and the feedforward module.
        qkv_bias (bool): Add bias in projection layer. Default: False.
        act_cfg (dict): Activation cfg for feedforward module. The default
            dict is shared between calls; this class only forwards it.
        operation_order (tuple[str] | list[str], optional): The execution
            order of operation in transformer. Either
            ('self_attn', 'norm', 'ffn', 'norm') or
            ('norm', 'self_attn', 'norm', 'ffn'). A list with the same
            content is accepted and converted to a tuple.
            Default: None, which selects
            ('norm', 'self_attn', 'norm', 'ffn').
    """

    # The only two execution orders this layer knows how to run.
    _PRE_NORM_ORDER = ('norm', 'self_attn', 'norm', 'ffn')
    _POST_NORM_ORDER = ('self_attn', 'norm', 'ffn', 'norm')

    def __init__(self,
                 d_model=512,
                 d_inner=256,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False,
                 act_cfg=dict(type='mmengine.GELU'),
                 operation_order=None):
        super().__init__()
        # Sub-modules are created in the original order so that parameter
        # registration (and therefore state_dict key order) is unchanged.
        self.attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.mlp = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout, act_cfg=act_cfg)
        self.norm2 = nn.LayerNorm(d_model)

        if operation_order is None:
            operation_order = self._PRE_NORM_ORDER
        elif isinstance(operation_order, list):
            # Configs loaded from JSON/YAML deliver lists; a list never
            # compares equal to a tuple, so normalize before validating.
            operation_order = tuple(operation_order)
        self.operation_order = operation_order
        assert self.operation_order in (self._PRE_NORM_ORDER,
                                        self._POST_NORM_ORDER), (
            f'Unsupported operation_order {self.operation_order!r}; '
            f'expected {self._PRE_NORM_ORDER} or {self._POST_NORM_ORDER}.')

    def forward(self, x, mask=None):
        """Run self-attention and the feedforward network on ``x``.

        Args:
            x (Tensor): Input features. It is used as query, key and value
                of the attention module.
            mask (Tensor, optional): Forwarded unchanged to the attention
                module as its fourth positional argument. Default: None.

        Returns:
            Tensor: The output of the layer.

        Raises:
            ValueError: If ``self.operation_order`` is not one of the two
                supported orders.
        """
        order = self.operation_order
        if order == self._POST_NORM_ORDER:
            # Post-norm: normalize after each residual addition.
            x = self.norm1(x + self.attn(x, x, x, mask))
            x = self.norm2(x + self.mlp(x))
        elif order == self._PRE_NORM_ORDER:
            # Pre-norm: normalize the sub-layer input, keep the residual
            # path un-normalized.
            normed = self.norm1(x)
            x = x + self.attn(normed, normed, normed, mask)
            x = x + self.mlp(self.norm2(x))
        else:
            # Reachable only when the constructor check was skipped
            # (``python -O``) or the attribute was changed afterwards;
            # fail loudly instead of returning the input untouched.
            raise ValueError(f'Unsupported operation_order {order!r}.')
        return x
class TFDecoderLayer(nn.Module):
    """Transformer Decoder Layer.

    The layer is made of a self-attention sub-layer, an encoder-decoder
    attention sub-layer and a position-wise feed-forward sub-layer. Each
    sub-layer is wrapped in a residual connection, and ``operation_order``
    selects whether layer normalization is applied before the sub-layer
    (pre-norm) or after the residual addition (post-norm).

    Args:
        d_model (int): The number of expected features
            in the decoder inputs (default=512).
        d_inner (int): The dimension of the feedforward
            network model (default=256).
        n_head (int): The number of heads in the
            multiheadattention models (default=8).
        d_k (int): Total number of features in key.
        d_v (int): Total number of features in value.
        dropout (float): Dropout value forwarded to both attention modules
            and the feedforward module.
        qkv_bias (bool): Add bias in projection layer. Default: False.
        act_cfg (dict): Activation cfg for feedforward module. The default
            dict is shared between calls; this class only forwards it.
        operation_order (tuple[str] | list[str], optional): The execution
            order of operation in transformer. Either
            ('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', 'norm') or
            ('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn').
            A list with the same content is accepted and converted to a
            tuple. Default: None, which selects
            ('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn').
    """

    # The only two execution orders this layer knows how to run.
    _PRE_NORM_ORDER = ('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm',
                       'ffn')
    _POST_NORM_ORDER = ('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn',
                        'norm')

    def __init__(self,
                 d_model=512,
                 d_inner=256,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False,
                 act_cfg=dict(type='mmengine.GELU'),
                 operation_order=None):
        super().__init__()
        # Sub-modules are created in the original order so that parameter
        # registration (and therefore state_dict key order) is unchanged.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)
        self.enc_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)
        self.mlp = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout, act_cfg=act_cfg)

        if operation_order is None:
            operation_order = self._PRE_NORM_ORDER
        elif isinstance(operation_order, list):
            # Configs loaded from JSON/YAML deliver lists; a list never
            # compares equal to a tuple, so normalize before validating.
            operation_order = tuple(operation_order)
        self.operation_order = operation_order
        assert self.operation_order in (self._PRE_NORM_ORDER,
                                        self._POST_NORM_ORDER), (
            f'Unsupported operation_order {self.operation_order!r}; '
            f'expected {self._PRE_NORM_ORDER} or {self._POST_NORM_ORDER}.')

    def forward(self,
                dec_input,
                enc_output,
                self_attn_mask=None,
                dec_enc_attn_mask=None):
        """Run self-attention, encoder-decoder attention and the FFN.

        Args:
            dec_input (Tensor): Decoder-side features. Used as query, key
                and value of the self-attention module.
            enc_output (Tensor): Encoder-side features. Used as key and
                value of the encoder-decoder attention module.
            self_attn_mask (Tensor, optional): Forwarded unchanged to the
                self-attention module. Default: None.
            dec_enc_attn_mask (Tensor, optional): Forwarded unchanged to
                the encoder-decoder attention module. Default: None.

        Returns:
            Tensor: The output of the layer.

        Raises:
            ValueError: If ``self.operation_order`` is not one of the two
                supported orders.
        """
        # Residual additions below are deliberately out-of-place. An
        # in-place ``+=`` on a sub-module's output breaks autograd whenever
        # that sub-module's last op saved its output for backward, and
        # silently mutates the caller's tensor if the sub-module returns
        # an alias of its input.
        order = self.operation_order
        if order == self._POST_NORM_ORDER:
            # Post-norm: normalize after each residual addition.
            self_attn_out = self.self_attn(dec_input, dec_input, dec_input,
                                           self_attn_mask)
            dec_attn_out = self.norm1(self_attn_out + dec_input)

            cross_attn_out = self.enc_attn(dec_attn_out, enc_output,
                                           enc_output, dec_enc_attn_mask)
            enc_dec_attn_out = self.norm2(cross_attn_out + dec_attn_out)

            ffn_out = self.mlp(enc_dec_attn_out)
            mlp_out = self.norm3(ffn_out + enc_dec_attn_out)
        elif order == self._PRE_NORM_ORDER:
            # Pre-norm: normalize the sub-layer input, keep the residual
            # path un-normalized.
            dec_input_norm = self.norm1(dec_input)
            self_attn_out = self.self_attn(dec_input_norm, dec_input_norm,
                                           dec_input_norm, self_attn_mask)
            dec_attn_out = self_attn_out + dec_input

            enc_dec_attn_in = self.norm2(dec_attn_out)
            cross_attn_out = self.enc_attn(enc_dec_attn_in, enc_output,
                                           enc_output, dec_enc_attn_mask)
            enc_dec_attn_out = cross_attn_out + dec_attn_out

            ffn_out = self.mlp(self.norm3(enc_dec_attn_out))
            mlp_out = ffn_out + enc_dec_attn_out
        else:
            # Reachable only when the constructor check was skipped
            # (``python -O``) or the attribute was changed afterwards.
            raise ValueError(f'Unsupported operation_order {order!r}.')
        return mlp_out
|