""" |
|
Implementation of "Attention is All You Need" |
|
""" |
|
|
|
import torch.nn as nn |
|
|
|
from onmt.encoders.encoder import EncoderBase |
|
from onmt.modules import MultiHeadedAttention |
|
from onmt.modules.position_ffn import PositionwiseFeedForward |
|
from onmt.modules.position_ffn import ActivationFunction |
|
from onmt.utils.misc import sequence_mask |
|
from onmt.modules.rmsnorm import RMSNorm |
|
|
|
|
|
class TransformerEncoderLayer(nn.Module):
    """
    A single layer of the transformer encoder.

    Args:
        d_model (int): the dimension of keys/values/queries in
            MultiHeadedAttention, also the input size of
            the first layer of the PositionwiseFeedForward.
        heads (int): the number of heads for MultiHeadedAttention.
        d_ff (int): the hidden size of the second layer of the
            PositionwiseFeedForward.
        dropout (float): dropout probability (0-1.0).
        pos_ffn_activation_fn (ActivationFunction):
            activation function choice for the PositionwiseFeedForward layer
    """
    def __init__(
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        max_relative_positions=0,
        relative_positions_buckets=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
        add_qkvbias=False,
        num_kv=0,
        add_ffnbias=True,
        parallel_residual=False,
        layer_norm="standard",
        norm_eps=1e-6,
        use_ckpting=[],
        parallel_gpu=1,
    ):
        super(TransformerEncoderLayer, self).__init__()

        self.self_attn = MultiHeadedAttention(
            heads,
            d_model,
            dropout=attention_dropout,
            is_decoder=False,
            max_relative_positions=max_relative_positions,
            relative_positions_buckets=relative_positions_buckets,
            attn_type="self",
            add_qkvbias=add_qkvbias,
            num_kv=num_kv,
            use_ckpting=use_ckpting,
            parallel_gpu=parallel_gpu,
        )
        self.feed_forward = PositionwiseFeedForward(
            d_model,
            d_ff,
            dropout,
            pos_ffn_activation_fn,
            add_ffnbias,
            parallel_residual,
            layer_norm,
            norm_eps,
            use_ckpting=use_ckpting,
            parallel_gpu=parallel_gpu,
        )
        self.parallel_residual = parallel_residual
        if layer_norm == "standard":
            self.layer_norm = nn.LayerNorm(d_model, eps=norm_eps)
        elif layer_norm == "rms":
            self.layer_norm = RMSNorm(d_model, eps=norm_eps)
        else:
            raise ValueError(f"{layer_norm} layer norm type is not supported")
        self.dropout = nn.Dropout(dropout)
    def forward(self, layer_in, mask):
        """
        Args:
            layer_in (FloatTensor): ``(batch_size, src_len, model_dim)``
            mask (BoolTensor): ``(batch_size, 1, src_len, src_len)``,
                ``True`` at padded positions

        Returns:
            (FloatTensor):

            * layer_out ``(batch_size, src_len, model_dim)``
        """
        norm_layer_in = self.layer_norm(layer_in)
        context, _ = self.self_attn(
            norm_layer_in, norm_layer_in, norm_layer_in, mask=mask
        )
        if self.parallel_residual:
            # feed_forward already adds its own residual (norm_layer_in here),
            # so subtract it back out and add the true residual `layer_in`.
            layer_out = (
                self.feed_forward(norm_layer_in)
                - norm_layer_in
                + layer_in
                + self.dropout(context)
            )
        else:
            layer_out = self.dropout(context) + layer_in
            layer_out = self.feed_forward(layer_out)

        return layer_out

    def update_dropout(self, dropout, attention_dropout):
        self.self_attn.update_dropout(attention_dropout)
        self.feed_forward.update_dropout(dropout)
        self.dropout.p = dropout


class TransformerEncoder(EncoderBase):
    """The Transformer encoder from "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    Args:
        num_layers (int): number of encoder layers
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        dropout (float): dropout probability
        embeddings (onmt.modules.Embeddings):
            embeddings to use, should have positional encodings
        pos_ffn_activation_fn (ActivationFunction):
            activation function choice for the PositionwiseFeedForward layer

    Returns:
        (torch.FloatTensor, torch.FloatTensor):

        * enc_out ``(batch_size, src_len, model_dim)``
        * encoder final state: None in the case of Transformer
        * src_len ``(batch_size)``
    """
    def __init__(
        self,
        num_layers,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        embeddings,
        max_relative_positions,
        relative_positions_buckets,
        pos_ffn_activation_fn=ActivationFunction.relu,
        add_qkvbias=False,
        num_kv=0,
        add_ffnbias=True,
        parallel_residual=False,
        layer_norm="standard",
        norm_eps=1e-6,
        use_ckpting=[],
        parallel_gpu=1,
    ):
        super(TransformerEncoder, self).__init__()

        self.embeddings = embeddings
        self.transformer = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_model,
                    heads,
                    d_ff,
                    dropout,
                    attention_dropout,
                    max_relative_positions=max_relative_positions,
                    relative_positions_buckets=relative_positions_buckets,
                    pos_ffn_activation_fn=pos_ffn_activation_fn,
                    add_qkvbias=add_qkvbias,
                    num_kv=num_kv,
                    add_ffnbias=add_ffnbias,
                    parallel_residual=parallel_residual,
                    layer_norm=layer_norm,
                    norm_eps=norm_eps,
                    use_ckpting=use_ckpting,
                    parallel_gpu=parallel_gpu,
                )
                for _ in range(num_layers)
            ]
        )
        if layer_norm == "standard":
            self.layer_norm = nn.LayerNorm(d_model, eps=norm_eps)
        elif layer_norm == "rms":
            self.layer_norm = RMSNorm(d_model, eps=norm_eps)
        else:
            raise ValueError(f"{layer_norm} layer norm type is not supported")
    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.enc_layers,
            opt.enc_hid_size,
            opt.heads,
            opt.transformer_ff,
            opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
            opt.attention_dropout[0]
            if type(opt.attention_dropout) is list
            else opt.attention_dropout,
            embeddings,
            opt.max_relative_positions,
            opt.relative_positions_buckets,
            pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
            add_qkvbias=opt.add_qkvbias,
            num_kv=opt.num_kv,
            add_ffnbias=opt.add_ffnbias,
            parallel_residual=opt.parallel_residual,
            layer_norm=opt.layer_norm,
            norm_eps=opt.norm_eps,
            use_ckpting=opt.use_ckpting,
            parallel_gpu=opt.world_size
            if opt.parallel_mode == "tensor_parallel"
            else 1,
        )
    def forward(self, src, src_len=None):
        """See :func:`EncoderBase.forward()`"""
        enc_out = self.embeddings(src)
        mask = ~sequence_mask(src_len).unsqueeze(1)
        mask = mask.unsqueeze(1)
        mask = mask.expand(-1, -1, mask.size(3), -1)
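        # Mask construction, step by step:
        #   sequence_mask(src_len)        -> (batch, src_len), True at valid tokens
        #   ~(...).unsqueeze(1)           -> (batch, 1, src_len), True at padding
        #   .unsqueeze(1)                 -> (batch, 1, 1, src_len)
        #   .expand(-1, -1, src_len, -1)  -> (batch, 1, src_len, src_len)
        # i.e. one padding mask per query position, broadcast over heads, as
        # expected by TransformerEncoderLayer / MultiHeadedAttention.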
        for layer in self.transformer:
            enc_out = layer(enc_out, mask)
        enc_out = self.layer_norm(enc_out)
        return enc_out, None, src_len

    def update_dropout(self, dropout, attention_dropout):
        self.embeddings.update_dropout(dropout)
        for layer in self.transformer:
            layer.update_dropout(dropout, attention_dropout)