# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from collections import namedtuple
import torch
import torch.nn as nn
from fairseq import checkpoint_utils
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqDecoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.speech_to_text import (
TransformerDecoder,
S2TTransformerEncoder,
)
from fairseq.models.transformer import TransformerEncoder
from fairseq.modules import (
TransformerEncoderLayer,
GradMultiply,
LayerNorm,
)
logger = logging.getLogger(__name__)
class SpeechEoSEncoder(FairseqEncoder):
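    """Wrapper around a speech encoder that appends learned EOS frames to the
    input features and optionally applies a small adapter (Linear or MLP) on
    top of the encoder output."""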
def __init__(self, encoder, eos_num, feat_dim, adapter_type="None", adapter_dim=0):
super().__init__(None)
self.encoder = encoder
        self.eos_num = eos_num  # number of eos frames appended to the speech input features
self.eos_emb = (
nn.Parameter(torch.zeros(1, feat_dim), requires_grad=True)
if eos_num > 0
else None
)
self.adapter = self.add_adapter(adapter_type, adapter_dim)
def add_adapter(self, adapter_type, adapter_dim):
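        # Each adapter linear layer is initialized to be (approximately) the
        # identity map: its weights are scaled towards zero and the diagonal is
        # set to 1, so the adapter starts out as a near no-op.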
def _make_identity(linear, eps=1e-5):
assert isinstance(linear, nn.Linear)
linear.weight.data.mul_(eps)
linear.weight.data.fill_diagonal_(1.0)
if linear.bias is not None:
linear.bias.data.mul_(eps)
adapter = None
if adapter_type == "Linear":
assert adapter_dim > 0
adapter = nn.Sequential(
nn.Linear(adapter_dim, adapter_dim), LayerNorm(adapter_dim)
)
# initialize the adapter as identity matrix first
_make_identity(adapter[0])
elif adapter_type == "MLP":
assert adapter_dim > 0
            # assume the model is a pre-norm model
adapter = nn.Sequential(
nn.Linear(adapter_dim, 2 * adapter_dim),
nn.ReLU(),
nn.Linear(2 * adapter_dim, adapter_dim),
LayerNorm(adapter_dim),
)
_make_identity(adapter[0])
_make_identity(adapter[2])
return adapter
def add_eos(self, src_tokens, src_lengths):
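        # Append self.eos_num copies of the learned EOS embedding right after
        # the last valid frame of each utterance and grow src_lengths to match.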
bsz, max_seq_len, fdim = src_tokens.size()
if self.eos_num > 0:
src_token_eos = torch.zeros(
[bsz, max_seq_len + self.eos_num, fdim],
dtype=src_tokens.dtype,
device=src_tokens.device,
)
src_token_eos[:, :max_seq_len] = src_tokens
for bi in range(bsz):
src_token_eos[bi][
src_lengths[bi] : src_lengths[bi] + self.eos_num
] = self.eos_emb.expand(self.eos_num, fdim)
src_lengths = src_lengths + self.eos_num
src_tokens = src_token_eos
return src_tokens, src_lengths
def apply_adapter(self, enc_out):
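        # Run the adapter on the encoder output (T x B x C) and zero out padded
        # positions so the adapter bias does not leak into the padding.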
if self.adapter is None:
return enc_out
rst = self.adapter(enc_out.encoder_out)
if enc_out.encoder_padding_mask is not None:
rst.masked_fill_(
enc_out.encoder_padding_mask.transpose(0, 1).unsqueeze(-1), 0
)
return EncoderOut(
encoder_out=rst,
encoder_padding_mask=enc_out.encoder_padding_mask,
encoder_embedding=enc_out.encoder_embedding,
encoder_states=enc_out.encoder_states,
src_tokens=enc_out.src_tokens,
src_lengths=enc_out.src_lengths,
)
def forward(self, src_tokens, src_lengths=None, return_all_hiddens=False, **kwargs):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
src_tokens, src_lengths = self.add_eos(src_tokens, src_lengths)
enc_out = self.encoder(src_tokens, src_lengths, return_all_hiddens)
enc_out = self.apply_adapter(enc_out)
return enc_out
def reorder_encoder_out(self, encoder_out, new_order):
return self.encoder.reorder_encoder_out(encoder_out, new_order)
class DualInputEncoder(FairseqEncoder):
def __init__(
self,
args,
spch_encoder,
text_encoder,
dictionary,
cross_attentive_loss_before_last_layer=-1,
):
super().__init__(dictionary)
self.spch_encoder = spch_encoder
self.text_encoder = text_encoder
self.enc_grad_mult = args.enc_grad_mult
self.cross_attentive_loss_before_last_layer = (
cross_attentive_loss_before_last_layer
)
self.use_cross_attentive_loss = (
False if cross_attentive_loss_before_last_layer <= -1 else True
)
self.enc2_along_grad_mult = args.enc2_along_grad_mult
@classmethod
def set_shared_layer(cls, share_level, src_layer, tgt_layer):
"""
share parameters from tgt_layer to src_layer
share_level:
0: share everything
1: share everything but different model
2: share weight but not bias, layernorm
"""
if share_level == 0:
return tgt_layer
if isinstance(src_layer, nn.Linear):
return tgt_layer
if isinstance(src_layer, TransformerEncoderLayer):
assert src_layer.embed_dim == tgt_layer.embed_dim
assert src_layer.normalize_before == tgt_layer.normalize_before
if share_level == 1:
src_layer.fc1 = tgt_layer.fc1
src_layer.fc2 = tgt_layer.fc2
src_layer.self_attn = tgt_layer.self_attn
src_layer.final_layer_norm = tgt_layer.final_layer_norm
src_layer.self_attn_layer_norm = tgt_layer.self_attn_layer_norm
src_layer.layernorm_embedding = tgt_layer.layernorm_embedding
else:
src_layer.fc1.weight = tgt_layer.fc1.weight
src_layer.fc2.weight = tgt_layer.fc2.weight
src_layer.self_attn.k_proj.weight = tgt_layer.self_attn.k_proj.weight
src_layer.self_attn.v_proj.weight = tgt_layer.self_attn.v_proj.weight
src_layer.self_attn.q_proj.weight = tgt_layer.self_attn.q_proj.weight
src_layer.self_attn.out_proj.weight = (
tgt_layer.self_attn.out_proj.weight
)
else:
if share_level == 1:
return tgt_layer
return src_layer
@classmethod
def build_spch_encoder(cls, args):
cfg = {
"input_feat_per_channel": args.input_feat_per_channel,
"input_channels": args.input_channels,
"conv_kernel_sizes": args.conv_kernel_sizes,
"conv_channels": args.conv_channels,
"encoder_embed_dim": args.encoder_embed_dim,
"encoder_ffn_embed_dim": args.encoder_ffn_embed_dim,
"encoder_layers": args.speech_encoder_layers,
"encoder_layerdrop": args.encoder_layerdrop,
"encoder_attention_heads": args.encoder_attention_heads,
"max_source_positions": args.max_source_positions,
"dropout": args.dropout,
"encoder_normalize_before": args.encoder_normalize_before,
"activation_dropout": args.activation_dropout,
"attention_dropout": args.attention_dropout,
"activation_fn": args.activation_fn,
"layernorm_embedding": args.layernorm_embedding,
"no_token_positional_embeddings": args.no_token_positional_embeddings,
"no_scale_embedding": args.no_scale_embedding,
"quant_noise_pq": args.quant_noise_pq,
"encoder_freezing_updates": 0,
}
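        # Wrap the config dict in a namedtuple so S2TTransformerEncoder can
        # read it like an argparse Namespace.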
model_args = namedtuple("args", cfg.keys())(*cfg.values())
spch_encoder = S2TTransformerEncoder(model_args)
if args.add_speech_eos:
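            # Wrap the speech encoder so that learned EOS frames are appended to
            # the input features (two frames per conv subsampling layer).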
spch_encoder = SpeechEoSEncoder(
spch_encoder,
2 * len(args.conv_kernel_sizes.split(",")),
args.input_feat_per_channel,
adapter_type=getattr(args, "speech_encoder_adapter_type", "None"),
adapter_dim=args.encoder_embed_dim,
)
return spch_encoder
@classmethod
def build_text_encoder(cls, args, src_dictionary, spch_encoder):
if args.encoder_shared_layers > 0:
mx_shared_layers = (
args.speech_encoder_layers
if args.speech_encoder_layers < args.text_encoder_layers
else args.text_encoder_layers
)
args.encoder_shared_layers = (
args.encoder_shared_layers
if args.encoder_shared_layers <= mx_shared_layers
else mx_shared_layers
)
cfg = {
"encoder_embed_dim": args.encoder_text_embed_dim,
"encoder_ffn_embed_dim": args.encoder_ffn_embed_dim,
"encoder_layers": args.text_encoder_layers,
"encoder_layerdrop": args.encoder_layerdrop,
"encoder_attention_heads": args.encoder_attention_heads,
"encoder_learned_pos": args.encoder_learned_pos,
"max_source_positions": args.max_source_positions,
"dropout": args.dropout,
"encoder_normalize_before": args.encoder_normalize_before,
"activation_dropout": args.activation_dropout,
"attention_dropout": args.attention_dropout,
"activation_fn": args.activation_fn,
"adaptive_input": args.adaptive_input,
"no_token_positional_embeddings": args.no_token_positional_embeddings,
"no_scale_embedding": args.no_scale_embedding,
"quant_noise_pq": args.quant_noise_pq,
}
model_args = namedtuple("args", cfg.keys())(*cfg.values())
enc_emb = nn.Embedding(
len(src_dictionary), model_args.encoder_embed_dim, src_dictionary.pad()
)
text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
if args.add_speech_eos:
spch_encoder = spch_encoder.encoder
if args.encoder_shared_layers > 0:
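            # Share the top `encoder_shared_layers` speech-encoder layers (and
            # the final layer norm) with the last layers of the text encoder.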
text_encoder.layer_norm = cls.set_shared_layer(
args.encoder_shared_layer_level,
text_encoder.layer_norm,
spch_encoder.layer_norm,
)
for i, ly in enumerate(
spch_encoder.transformer_layers[-args.encoder_shared_layers :]
):
ly_id = i + args.text_encoder_layers - args.encoder_shared_layers
assert isinstance(text_encoder.layers[ly_id], type(ly))
text_encoder.layers[ly_id] = cls.set_shared_layer(
args.encoder_shared_layer_level,
text_encoder.layers[ly_id],
ly,
)
return text_encoder
def mult_rst_grad(self, rst, ratio):
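        # GradMultiply scales gradients in the backward pass only; the forward
        # output is unchanged.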
assert isinstance(rst, dict) # instead of EncoderOut
assert len(rst["encoder_out"]) == 1
rst["encoder_out"][0] = GradMultiply.apply(rst["encoder_out"][0], ratio)
return rst
def process_attentive_loss_states(self, rst, interstates):
assert isinstance(rst, dict) # instead of EncoderOut
rst["encoder_states"] = interstates
return rst
def forward(
self,
src_tokens,
src_lengths=None,
src_txt_tokens=None,
src_txt_lengths=None,
**kwargs
):
"""
Args:
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (speech) (B,)
src_txt_tokens: padded tensor (B, T)
src_txt_lengths: tensor of original lengths of input utterances (text) (B,)
"""
# src_tokens only: inference
# src_tokens, src_lengths: speech only training
# src_txt_tokens, src_txt_lengths: text only training
# all valid: speech + text training
if src_tokens is None and src_txt_tokens is None:
raise ValueError(
"src_tokens and src_txt_tokens cannot be None at the same time"
)
ret1 = None
ret2 = None
return_all_hiddens = False
if src_tokens is not None:
if (
self.use_cross_attentive_loss and src_txt_tokens is not None
            ):  # self.training is intentionally not checked so attention scores are available during validation too
return_all_hiddens = True
ret1 = self.spch_encoder(
src_tokens, src_lengths, return_all_hiddens=return_all_hiddens
)
if self.use_cross_attentive_loss and src_txt_tokens is not None:
assert self.cross_attentive_loss_before_last_layer < len(
ret1["encoder_states"]
)
ret1 = self.process_attentive_loss_states(
ret1,
ret1["encoder_states"][
-self.cross_attentive_loss_before_last_layer - 1
],
)
if src_txt_tokens is not None:
ret2 = self.text_encoder(
src_txt_tokens, src_txt_lengths, return_all_hiddens=return_all_hiddens
)
if return_all_hiddens:
if self.cross_attentive_loss_before_last_layer == len(
self.text_encoder.layers
):
text_embedding, _ = self.text_encoder.forward_embedding(
src_txt_tokens
)
text_embedding = text_embedding.transpose(0, 1)
ret2 = self.process_attentive_loss_states(ret2, text_embedding)
else:
assert self.cross_attentive_loss_before_last_layer < len(
self.text_encoder.layers
)
ret2 = self.process_attentive_loss_states(
ret2,
ret2["encoder_states"][
-self.cross_attentive_loss_before_last_layer - 1
],
)
def merge_output(rst1, rst2):
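            # rst1 is the speech encoder output, rst2 the text encoder output:
            # text-only batches return rst2, speech-only batches return rst1,
            # and joint batches return both outputs as a tuple.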
if rst1 is None:
if not (self.enc2_along_grad_mult == 1.0 or self.training):
rst2 = self.mult_rst_grad(rst2, self.enc2_along_grad_mult)
return rst2
if rst2 is None:
return rst1
if self.enc_grad_mult != 1.0 and self.training:
rst1 = self.mult_rst_grad(rst1, self.enc_grad_mult)
rst2 = self.mult_rst_grad(rst2, self.enc_grad_mult)
rst = (rst1, rst2)
return rst
return merge_output(ret1, ret2)
def reorder_encoder_out(self, encoder_out, new_order):
assert self.training is False # used for inference only
return self.spch_encoder.reorder_encoder_out(encoder_out, new_order)
# TransformerMultiInputDecoder: takes one or two encoder inputs
class TransformerMultiInputDecoder(FairseqDecoder):
def __init__(
self,
dictionary,
spch_decoder,
text_decoder,
compute_cross_attentive_loss=False,
cross_attentive_loss_with_norm=True,
cross_attentive_loss_reverse=False,
):
super().__init__(dictionary)
self.spch_decoder = spch_decoder
self.text_decoder = text_decoder
self.compute_cross_attentive_loss = compute_cross_attentive_loss
self.cross_attentive_loss_with_norm = cross_attentive_loss_with_norm
self.cross_attentive_loss_reverse = cross_attentive_loss_reverse
@classmethod
def share_spchdecoder(cls, task_args, text_decoder, spch_decoder):
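        """Tie spch_decoder parameters to text_decoder according to
        task_args.decoder_shared_layer_level:
            0: fully shared (return text_decoder itself)
            1: share sub-modules, but keep separate decoder instances
            2: share weights only; biases and layer norms stay separate
        """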
if task_args.decoder_shared_layer_level == 0:
return text_decoder
assert text_decoder.embed_tokens == spch_decoder.embed_tokens
spch_decoder.project_in_dim = text_decoder.project_in_dim
spch_decoder.embed_positions = text_decoder.embed_positions
spch_decoder.layernorm_embedding = text_decoder.layernorm_embedding
spch_decoder.project_out_dim = text_decoder.project_out_dim
spch_decoder.adaptive_softmax = text_decoder.adaptive_softmax
if task_args.decoder_shared_layer_level == 1:
spch_decoder.output_projection = text_decoder.output_projection
spch_decoder.layer_norm = text_decoder.layer_norm
else: # 2
spch_decoder.output_projection.weight = (
text_decoder.output_projection.weight
)
for i, ly in enumerate(text_decoder.layers):
sly = spch_decoder.layers[i]
sly.self_attn = ly.self_attn
sly.self_attn_layer_norm = ly.self_attn_layer_norm
# sly.encoder_attn = ly.encoder_attn
if (
task_args.decoder_shared_layer_level == 1
): # share everything, but under different models
sly.encoder_attn = ly.encoder_attn
sly.encoder_attn_layer_norm = ly.encoder_attn_layer_norm
sly.fc1 = ly.fc1
sly.fc2 = ly.fc2
sly.final_layer_norm = ly.final_layer_norm
            else:  # decoder_shared_layer_level == 2: separate encoder_attn_layer_norm and bias
sly.encoder_attn.k_proj.weight = ly.encoder_attn.k_proj.weight
sly.encoder_attn.v_proj.weight = ly.encoder_attn.v_proj.weight
sly.encoder_attn.q_proj.weight = ly.encoder_attn.q_proj.weight
sly.encoder_attn.out_proj.weight = ly.encoder_attn.out_proj.weight
sly.fc1.weight = ly.fc1.weight
sly.fc2.weight = ly.fc2.weight
return spch_decoder
def cross_attentive_loss(
self, teacher_states, student_states, teacher_masking, student_masking, eps=1e-6
):
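        # Reconstruct each teacher position from the student states (attention
        # over the student sequence) and from the teacher states themselves
        # (detached self-reconstruction), then penalize the L2 distance between
        # the two reconstructions.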
x = teacher_states.transpose(0, 1) # from T X B X D to B X T X D
y = student_states.transpose(0, 1)
if self.cross_attentive_loss_with_norm:
x = x / (x.norm(dim=2, keepdim=True) + eps)
y = y / (y.norm(dim=2, keepdim=True) + eps)
dim = x.size(-1)
# lengths: batch X seqLen
        sim_scores_xy = torch.bmm(x, y.transpose(1, 2))  # batch X lenx X leny
if y.dtype == torch.float16:
sim_scores_xy = sim_scores_xy.float()
y = y.float()
x = x.float()
if teacher_masking != []:
assert len(teacher_masking) == 1
sim_scores_xy = sim_scores_xy.masked_fill(
teacher_masking[0].unsqueeze(-1), float("-inf")
)
if student_masking != []:
sim_scores_xy = sim_scores_xy.masked_fill(
student_masking[0].unsqueeze(1), float("-inf")
)
# do masking
y_weights = utils.softmax(sim_scores_xy, dim=-1)
if teacher_masking != []:
y_weights = y_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0)
x_reconstruct_from_y = torch.bmm(y_weights, y)
        sim_scores_xx = torch.bmm(x, x.transpose(1, 2))  # batch X lenx X lenx
x_weights = utils.softmax(sim_scores_xx, dim=-1)
if teacher_masking != []:
x_weights = x_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0)
# no gradient for teacher state
x_reconstruct_from_x = torch.bmm(x_weights, x).detach()
cost = (x_reconstruct_from_x - x_reconstruct_from_y).norm(dim=2)
if teacher_masking != []:
cost = cost.masked_fill(teacher_masking[0], 0)
if not self.cross_attentive_loss_with_norm:
cost = cost / dim
return cost
def forward(
self,
prev_output_tokens,
encoder_out,
incremental_state=None,
has_txt_input=False,
**kwargs
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for input feeding/teacher forcing. If there are
                two or more inputs during training, they share the same prev_output_tokens
encoder_out (tuple[Tensor]): output from the encoder, used for
                encoder-side attention. It is a tuple when there are multiple
                inputs and a single encoder output otherwise
incremental_state ([dict]): dictionary used for storing state during
                :ref:`Incremental decoding`. It is only valid for inference with
                a single input
Returns:
tuple:
- the last decoder layer's output of shape `(batch, tgt_len,
                  vocab)`. If there are N inputs, the batch dimension is N times that of a single input
- the last decoder layer's attention weights of shape `(batch,
tgt_len, src_len)`
"""
assert not isinstance(encoder_out, EncoderOut)
        if isinstance(encoder_out, tuple):  # training with multiple inputs
rst = []
assert len(encoder_out) == 2
for i, eo in enumerate(encoder_out):
assert incremental_state is None
if i == 0:
rst.append(
self.spch_decoder(prev_output_tokens, eo, incremental_state)
)
else:
rst.append(
self.text_decoder(prev_output_tokens, eo, incremental_state)
)
dec_out = torch.cat([r[0] for r in rst], dim=0)
attn_cost = None
if self.compute_cross_attentive_loss:
assert isinstance(encoder_out[0], dict)
if self.cross_attentive_loss_reverse:
attn_cost = self.cross_attentive_loss(
teacher_states=encoder_out[1]["encoder_states"], # text_states
student_states=encoder_out[0]["encoder_states"], # spch_states
teacher_masking=encoder_out[1]["encoder_padding_mask"],
student_masking=encoder_out[0]["encoder_padding_mask"],
)
else:
attn_cost = self.cross_attentive_loss(
teacher_states=encoder_out[0]["encoder_states"], # spch_states
student_states=encoder_out[1]["encoder_states"], # text_states
teacher_masking=encoder_out[0]["encoder_padding_mask"],
student_masking=encoder_out[1]["encoder_padding_mask"],
)
return (dec_out, {"attn_cost": attn_cost})
else: # inference or training with one input
if has_txt_input:
return self.text_decoder(
prev_output_tokens, encoder_out, incremental_state
)
return self.spch_decoder(prev_output_tokens, encoder_out, incremental_state)
# Note:
# dual input transformer:
# encoder: S2TTransformerEncoder for speech + TransformerEncoder for text
# decoder: TransformerDecoder for text
@register_model("dual_input_s2t_transformer")
class DualInputS2TTransformerModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
self.num_updates = 0
def max_positions(self):
        return None  # max positions are provided by the task
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# encoder 1: S2TTransformerEncoder for speech
parser.add_argument(
"--conv-kernel-sizes",
type=str,
metavar="N",
help="kernel sizes of Conv1d subsampling layers",
)
parser.add_argument(
"--conv-channels",
type=int,
metavar="N",
help="# of channels in Conv1d subsampling layers",
)
parser.add_argument(
"--enc-output-dim",
type=int,
metavar="N",
help="""
            encoder output dimension, can be None. If specified, the
            transformer output is projected to the specified dimension""",
)
# standard Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-text-embed-dim",
type=int,
metavar="N",
help="encoder text embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
# non-standard transformer parameters
parser.add_argument(
"--speech-encoder-layers",
type=int,
metavar="N",
help="num speech encoder layers",
)
parser.add_argument(
"--text-encoder-layers",
type=int,
metavar="N",
help="num text encoder layers",
)
parser.add_argument(
"--encoder-shared-layers",
type=int,
metavar="N",
help="num shared encoder layers",
)
parser.add_argument(
"--encoder-shared-layer-level",
type=int,
metavar="N",
default=0,
choices=[0, 1, 2],
help="share layer level 0: all share 1: all share with separate model 2: share weight but not bias and layernorm",
)
parser.add_argument(
"--decoder-shared-layer-level",
default=0,
choices=[0, 1, 2],
type=int,
metavar="N",
help="0: share everything; 1: share everything with different model 2: no share layer_norm and bias",
)
###
parser.add_argument(
"--text-input-cost-ratio",
type=float,
default=1.0,
metavar="V",
help="text input cost ratio relative to speech input cost",
)
parser.add_argument(
"--init-scale",
type=float,
default=1.0,
metavar="V",
help="scale the initial weight by given factor",
)
parser.add_argument(
"--enc-grad-mult",
type=float,
metavar="V",
default=1.0,
help="multiply enc1 and enc2 gradient by V",
)
parser.add_argument(
"--enc2-along-grad-mult",
type=float,
metavar="V",
default=1.0,
help="multiply enc2 gradient by V if only enc2 is used",
)
parser.add_argument(
"--load-pretrain-encoder",
type=str,
default="",
metavar="EXPR",
help=""" path to the pretrained encoder """,
)
parser.add_argument(
"--load-pretrain-speech-encoder",
type=str,
default="",
metavar="EXPR",
help=""" path to the pretrained speech encoder """,
)
parser.add_argument(
"--load-pretrain-text-encoder",
type=str,
default="",
metavar="EXPR",
help=""" path to the pretrained text encoder """,
)
parser.add_argument(
"--load-pretrain-text-encoder-last",
type=str,
default="",
metavar="EXPR",
help=""" path to the pretrained text encoder """,
)
parser.add_argument(
"--load-pretrain-decoder",
type=str,
metavar="EXPR",
default="",
help=""" path to the pretrained encoder """,
)
parser.add_argument(
"--add-speech-eos",
action="store_true",
help="add eos token at the end of input feature",
)
parser.add_argument(
"--speech-encoder-adapter-type",
type=str,
metavar="EXPR",
default="None",
choices=["None", "Linear", "MLP"],
help="add speech encoder adapter",
)
@classmethod
def build_encoder(cls, args, task):
spch_encoder = DualInputEncoder.build_spch_encoder(args)
text_encoder = DualInputEncoder.build_text_encoder(
args, task.src_dict, spch_encoder
)
cross_attentive_loss_before_last_layer = (
0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
)
encoder = DualInputEncoder(
args,
spch_encoder,
text_encoder,
task.src_dict,
cross_attentive_loss_before_last_layer,
)
if args.init_scale != 1.0:
with torch.no_grad():
for param in encoder.parameters():
param.data.mul_(args.init_scale)
if args.load_pretrain_text_encoder != "":
checkpoint_utils.load_pretrained_component_from_model(
text_encoder, args.load_pretrain_text_encoder
)
if args.load_pretrain_speech_encoder != "":
if hasattr(spch_encoder, "encoder"):
checkpoint_utils.load_pretrained_component_from_model(
spch_encoder.encoder, args.load_pretrain_speech_encoder
)
else:
checkpoint_utils.load_pretrained_component_from_model(
spch_encoder, args.load_pretrain_speech_encoder
)
if (
args.load_pretrain_text_encoder_last != ""
        ):  # if the encoder layers are shared, the pretrained speech encoder
            # parameters would be used; this provides a chance to load a
            # pre-trained MT encoder instead
checkpoint_utils.load_pretrained_component_from_model(
text_encoder, args.load_pretrain_text_encoder_last
)
if args.load_pretrain_encoder != "":
checkpoint_utils.load_pretrained_component_from_model(
encoder, args.load_pretrain_encoder
)
return encoder
@classmethod
def build_decoder(cls, args, task):
dec_cfg = {
"decoder_layerdrop": args.decoder_layerdrop,
"share_decoder_input_output_embed": args.share_decoder_input_output_embed,
"decoder_embed_dim": args.decoder_embed_dim,
"max_target_positions": args.max_target_positions,
"dropout": args.dropout,
"encoder_learned_pos": args.encoder_learned_pos,
"decoder_learned_pos": args.decoder_learned_pos,
"layernorm_embedding": args.layernorm_embedding,
"decoder_normalize_before": args.decoder_normalize_before,
"activation_dropout": args.activation_dropout,
"attention_dropout": args.attention_dropout,
"decoder_ffn_embed_dim": args.decoder_ffn_embed_dim,
"decoder_layers": args.decoder_layers,
"decoder_attention_heads": args.decoder_attention_heads,
"decoder_output_dim": args.decoder_embed_dim,
"no_scale_embedding": args.no_scale_embedding,
"adaptive_input": args.adaptive_input,
"quant_noise_pq": args.quant_noise_pq,
"adaptive_softmax_cutoff": args.adaptive_softmax_cutoff,
"tie_adaptive_weights": args.tie_adaptive_weights,
"no_token_positional_embeddings": args.no_token_positional_embeddings,
}
dec_cfg = namedtuple("args", dec_cfg.keys())(*dec_cfg.values())
dec_emb = nn.Embedding(
len(task.target_dictionary),
args.decoder_embed_dim,
task.target_dictionary.pad(),
)
compute_cross_attentive_loss = (
True if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else False
)
cross_attentive_loss_without_norm = getattr(
args, "attentive_cost_without_normalize", False
)
cross_attentive_loss_reverse = (
False # getattr(args, "attentive_cost_reverse", False)
)
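        # Both decoders are built with the same target dictionary and embedding
        # table (dec_emb), so share_spchdecoder can tie their remaining parameters.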
text_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb)
spch_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb)
spch_decoder = TransformerMultiInputDecoder.share_spchdecoder(
args, text_decoder, spch_decoder
)
decoder = TransformerMultiInputDecoder(
dictionary=task.target_dictionary,
spch_decoder=spch_decoder,
text_decoder=text_decoder,
compute_cross_attentive_loss=compute_cross_attentive_loss,
cross_attentive_loss_with_norm=True
if not cross_attentive_loss_without_norm
else False,
cross_attentive_loss_reverse=cross_attentive_loss_reverse,
)
if args.init_scale != 1.0:
with torch.no_grad():
for param in decoder.parameters():
param.data.mul_(args.init_scale)
if args.load_pretrain_decoder != "":
try:
checkpoint_utils.load_pretrained_component_from_model(
decoder, args.load_pretrain_decoder
)
except RuntimeError:
checkpoint_utils.load_pretrained_component_from_model(
decoder.text_decoder, args.load_pretrain_decoder
)
if args.decoder_shared_layer_level > 0:
checkpoint_utils.load_pretrained_component_from_model(
decoder.spch_decoder, args.load_pretrain_decoder
)
return decoder
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure that all args are properly defaulted
# (in case there are any new ones)
dualinputs2ttransformer_base(args)
encoder = cls.build_encoder(args, task)
decoder = cls.build_decoder(args, task)
return cls(encoder, decoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
        # the decoder output net_output[0] is a (B, T, V) tensor (batch first)
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
lprobs.batch_first = True
return lprobs
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
super().set_num_updates(num_updates)
self.num_updates = num_updates
def forward(
self,
src_tokens,
src_lengths,
prev_output_tokens,
use_encoder_outputs=False,
src_txt_tokens=None,
src_txt_lengths=None,
mode="sup_speech",
**kwargs
):
"""
Run the forward pass for an encoder-decoder model.
First feed a batch of source tokens through the encoder. Then, feed the
encoder output and previous decoder outputs (i.e., teacher forcing) to
the decoder to produce the next outputs::
encoder_out = self.encoder(src_tokens, src_lengths)
return self.decoder(prev_output_tokens, encoder_out)
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (LongTensor): source sentence lengths of shape `(batch)`
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
            mode (str): 'sup_speech' or 'text'
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
if mode == "text":
assert src_txt_tokens is None
src_txt_tokens = src_tokens
src_txt_lengths = src_lengths
src_tokens = None
src_lengths = None
encoder_out = self.encoder(
src_tokens,
src_lengths=src_lengths,
src_txt_tokens=src_txt_tokens,
src_txt_lengths=src_txt_lengths,
**kwargs
)
has_txt_input = True if src_txt_tokens is not None else False
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,
has_txt_input=has_txt_input,
**kwargs
)
if use_encoder_outputs:
return decoder_out, encoder_out
return decoder_out
@register_model_architecture(
"dual_input_s2t_transformer", "dualinputs2ttransformer_base"
)
def dualinputs2ttransformer_base(args):
args.encoder_freezing_updates = getattr(args, "encoder_freezing_updates", 0)
# Convolutional subsampler
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
args.conv_channels = getattr(args, "conv_channels", 1024)
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_text_embed_dim = getattr(
args, "encoder_text_embed_dim", args.encoder_embed_dim
)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 10)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
args.encoder_shared_layers = getattr(args, "encoder_shared_layers", 0)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.add_speech_eos = getattr(args, "add_speech_eos", False)
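# The variants below override a subset of the base hyper-parameters and then
# call dualinputs2ttransformer_base for the remaining defaults; each registered
# name can be selected at training time via fairseq's --arch flag
# (e.g. --arch dualinputs2ttransformer_s).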
@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_s")
def dualinputs2ttransformer_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 7)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 7)
args.decoder_layers = getattr(args, "decoder_layers", 7)
dualinputs2ttransformer_base(args)
@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_m")
def dualinputs2ttransformer_m(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.15)
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 10)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
args.decoder_layers = getattr(args, "decoder_layers", 6)
dualinputs2ttransformer_base(args)
@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_b")
def dualinputs2ttransformer_b(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
args.dropout = getattr(args, "dropout", 0.15)
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
args.decoder_layers = getattr(args, "decoder_layers", 6)
dualinputs2ttransformer_base(args)
@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_l")
def dualinputs2ttransformer_l(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.2)
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
args.decoder_layers = getattr(args, "decoder_layers", 6)
dualinputs2ttransformer_base(args)