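"""Simultaneous speech-to-text translation models built on ConvTransformer.

Defines the simultaneous ConvTransformer model with a monotonic-attention
decoder, an augmented-memory variant, and an Emformer-based streaming
encoder, together with their fairseq architecture registrations.
"""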

from typing import Dict, List

from torch import nn, Tensor

from fairseq import checkpoint_utils
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.speech_to_text import (
    ConvTransformerModel,
    convtransformer_espnet,
    ConvTransformerEncoder,
)
from fairseq.models.speech_to_text.modules.augmented_memory_attention import (
    augmented_memory,
    SequenceEncoder,
    AugmentedMemoryConvTransformerEncoder,
)
from fairseq.models.speech_to_text.modules.emformer import (
    NoSegAugmentedMemoryTransformerEncoderLayer,
)


@register_model("convtransformer_simul_trans")
class SimulConvTransformerModel(ConvTransformerModel):
    """
    Implementation of the paper:

    SimulMT to SimulST: Adapting Simultaneous Text Translation to
    End-to-End Simultaneous Speech Translation

    https://www.aclweb.org/anthology/2020.aacl-main.58.pdf
    """

    @staticmethod
    def add_args(parser):
        super(SimulConvTransformerModel, SimulConvTransformerModel).add_args(parser)
        parser.add_argument(
            "--train-monotonic-only",
            action="store_true",
            default=False,
            help="Only train monotonic attention",
        )
    @classmethod
    def build_decoder(cls, args, task, embed_tokens):
        tgt_dict = task.tgt_dict
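
        # The monotonic decoder is imported lazily from the examples package,
        # presumably to avoid a circular import at module load time (an
        # assumption; the original does not explain the deferred import).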
        from examples.simultaneous_translation.models.transformer_monotonic_attention import (
            TransformerMonotonicDecoder,
        )

        decoder = TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)

        if getattr(args, "load_pretrained_decoder_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_decoder_from
            )
        return decoder
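

# Each register_model_architecture entry below reuses the ESPnet-style
# ConvTransformer defaults defined by convtransformer_espnet.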
@register_model_architecture(
    "convtransformer_simul_trans", "convtransformer_simul_trans_espnet"
)
def convtransformer_simul_trans_espnet(args):
    convtransformer_espnet(args)
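

# The @augmented_memory decorator adds the streaming-related options
# (segment size, left/right context, memory size) to the model, and
# SequenceEncoder feeds the utterance to the wrapped augmented-memory
# encoder one segment at a time.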
@register_model("convtransformer_augmented_memory")
@augmented_memory
class AugmentedMemoryConvTransformerModel(SimulConvTransformerModel):
    @classmethod
    def build_encoder(cls, args):
        encoder = SequenceEncoder(args, AugmentedMemoryConvTransformerEncoder(args))

        if getattr(args, "load_pretrained_encoder_from", None) is not None:
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )

        return encoder


@register_model_architecture(
    "convtransformer_augmented_memory", "convtransformer_augmented_memory"
)
def augmented_memory_convtransformer_espnet(args):
    convtransformer_espnet(args)
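

# Emformer-based streaming encoder: keeps the ConvTransformer convolutional
# front end but installs a single NoSegAugmentedMemoryTransformerEncoderLayer,
# which internally stacks args.encoder_layers streaming layers, hence the
# one-element ModuleList below.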
class ConvTransformerEmformerEncoder(ConvTransformerEncoder):
    def __init__(self, args):
        super().__init__(args)
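        # The convolutional front end subsamples the time axis by `stride`,
        # so the segment contexts, given in input frames, are converted to
        # post-convolution frames before configuring the Emformer layer.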
        stride = self.conv_layer_stride(args)
        trf_left_context = args.segment_left_context // stride
        trf_right_context = args.segment_right_context // stride
        context_config = [trf_left_context, trf_right_context]
        self.transformer_layers = nn.ModuleList(
            [
                NoSegAugmentedMemoryTransformerEncoderLayer(
                    input_dim=args.encoder_embed_dim,
                    num_heads=args.encoder_attention_heads,
                    ffn_dim=args.encoder_ffn_embed_dim,
                    num_layers=args.encoder_layers,
                    dropout_in_attn=args.dropout,
                    dropout_on_attn=args.dropout,
                    dropout_on_fc1=args.dropout,
                    dropout_on_fc2=args.dropout,
                    activation_fn=args.activation_fn,
                    context_config=context_config,
                    segment_size=args.segment_length,
                    max_memory_size=args.max_memory_size,
                    scaled_init=True,
                    tanh_on_mem=args.amtrf_tanh_on_mem,
                )
            ]
        )
        self.conv_transformer_encoder = ConvTransformerEncoder(args)

    def forward(self, src_tokens, src_lengths):
        encoder_out: Dict[str, List[Tensor]] = self.conv_transformer_encoder(
            src_tokens, src_lengths.to(src_tokens.device)
        )
        output = encoder_out["encoder_out"][0]
        encoder_padding_masks = encoder_out["encoder_padding_mask"]

        return {
            "encoder_out": [output],
            "encoder_padding_mask": [encoder_padding_masks[0][:, : output.size(0)]]
            if len(encoder_padding_masks) > 0
            else [],
            "encoder_embedding": [],
            "encoder_states": [],
            "src_tokens": [],
            "src_lengths": [],
        }

    @staticmethod
    def conv_layer_stride(args):
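        # ConvTransformerEncoder's front end uses two stride-2 convolutions
        # along time, so the overall subsampling factor is fixed at 4.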
        return 4


@register_model("convtransformer_emformer")
class ConvtransformerEmformer(SimulConvTransformerModel):
    @staticmethod
    def add_args(parser):
        super(ConvtransformerEmformer, ConvtransformerEmformer).add_args(parser)

        parser.add_argument(
            "--segment-length",
            type=int,
            metavar="N",
            help="length of each segment (not including left context / right context)",
        )
        parser.add_argument(
            "--segment-left-context",
            type=int,
            help="length of left context in a segment",
        )
        parser.add_argument(
            "--segment-right-context",
            type=int,
            help="length of right context in a segment",
        )
        parser.add_argument(
            "--max-memory-size",
            type=int,
            default=-1,
            help="maximum size of the memory bank (-1 for unbounded)",
        )
        parser.add_argument(
            "--amtrf-tanh-on-mem",
            default=False,
            action="store_true",
            help="whether to use tanh on memory vector",
        )

    @classmethod
    def build_encoder(cls, args):
        encoder = ConvTransformerEmformerEncoder(args)
        if getattr(args, "load_pretrained_encoder_from", None):
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )
        return encoder


@register_model_architecture(
    "convtransformer_emformer",
    "convtransformer_emformer",
)
def convtransformer_emformer_base(args):
    convtransformer_espnet(args)
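

# Example usage (a sketch, not part of the original file; data preparation
# and the remaining optimization flags follow the standard fairseq
# speech_to_text recipe, and the values below are illustrative only):
#
#   fairseq-train ${DATA_ROOT} \
#       --task speech_to_text \
#       --arch convtransformer_emformer \
#       --segment-length 64 \
#       --segment-left-context 32 \
#       --segment-right-context 32 \
#       --max-memory-size 10 \
#       --amtrf-tanh-on-mem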