|
|
|
|
|
|
|
""" |
|
Conformer speech translation model (pytorch).

It builds on the Transformer speech translation model in
`e2e_st_transformer.py`, replacing its encoder with a Conformer encoder.
Refer to: https://arxiv.org/abs/2005.08100
|
|
|
""" |
|
|
|
from espnet.nets.pytorch_backend.conformer.encoder import Encoder |
|
from espnet.nets.pytorch_backend.e2e_st_transformer import E2E as E2ETransformer |
|
from espnet.nets.pytorch_backend.conformer.argument import ( |
|
add_arguments_conformer_common, |
|
verify_rel_pos_type, |
|
) |
|
|
|
|
|
class E2E(E2ETransformer):
    """Conformer-based E2E speech translation module.

    Extends the Transformer speech translation model (``E2ETransformer``) and
    replaces ``self.encoder`` with a Conformer ``Encoder`` built from ``args``.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options

    """

    @staticmethod
    def add_arguments(parser):
        """Register the Transformer options and the conformer options.

        :param parser: argument parser to extend
        :return: the same ``parser`` object that was passed in
        """
        E2ETransformer.add_arguments(parser)
        E2E.add_conformer_arguments(parser)
        return parser

    @staticmethod
    def add_conformer_arguments(parser):
        """Register the conformer-specific options in their own group.

        :param parser: argument parser to extend
        :return: the same ``parser`` object that was passed in
        """
        conformer_group = parser.add_argument_group(
            "conformer model specific setting"
        )
        # The helper's return value is not needed here; only ``parser`` is
        # handed back to the caller.
        add_arguments_conformer_common(conformer_group)
        return parser

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        The parent constructor runs first; afterwards ``self.encoder`` is
        replaced by a Conformer encoder. Note that ``args`` is modified in
        place when ``transformer_attn_dropout_rate`` is ``None``.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        :param int ignore_id: forwarded unchanged to the parent constructor
        """
        super().__init__(idim, odim, args, ignore_id)

        # Attention dropout falls back to the general dropout rate when unset.
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate

        # NOTE(review): presumably validates/adjusts the relative positional
        # encoding options -- confirm in conformer.argument. The returned
        # Namespace is the one used from here on.
        args = verify_rel_pos_type(args)

        # Keyword arguments are collected in the same order they are passed,
        # so attribute lookups on ``args`` happen in an unchanged sequence.
        encoder_options = dict(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            # Positional dropout reuses the general dropout rate.
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            pos_enc_layer_type=args.transformer_encoder_pos_enc_layer_type,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            activation_type=args.transformer_encoder_activation_type,
            macaron_style=args.macaron_style,
            use_cnn_module=args.use_cnn_module,
            cnn_module_kernel=args.cnn_module_kernel,
        )
        self.encoder = Encoder(**encoder_options)

        # Called after the encoder swap so it runs with the Conformer encoder
        # in place (``reset_parameters`` is not defined in this class).
        self.reset_parameters(args)
|
|