# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch.nn as nn

from fairseq import checkpoint_utils
from fairseq import utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
    register_model,
    register_model_architecture,
    FairseqEncoder,
)
from fairseq.models.speech_to_text import XMTransformerModel, Wav2VecEncoderWithAdaptor
from fairseq.models.speech_to_text.xm_transformer import (
    set_default_adaptor_args,
    set_default_w2v_encoder_args,
)
from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
from fairseq.models.wav2vec import TransformerSentenceEncoderLayer
from fairseq.utils import safe_hasattr

from .s2t_dualinputtransformer import (
    DualInputS2TTransformerModel,
    TransformerMultiInputDecoder,
    DualInputEncoder,
)


class TransformerSentenceEncoderLayerStd(TransformerSentenceEncoderLayer):
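    """Thin wrapper around an existing wav2vec TransformerSentenceEncoderLayer.

    It bypasses the parent constructor and reuses the wrapped layer's
    submodules (attention, FFN, layer norms), changing only ``forward`` so the
    layer returns hidden states alone. This lets a shared wav2vec layer be
    dropped into the mBART text encoder's layer list (see SharedEncoder below).
    """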
    def __init__(self, sent_enc_layer):
        super(TransformerSentenceEncoderLayer, self).__init__()
        self.embedding_dim = sent_enc_layer.embedding_dim
        self.dropout = sent_enc_layer.dropout
        self.activation_dropout = sent_enc_layer.activation_dropout

        # Initialize blocks
        self.activation_fn = sent_enc_layer.activation_fn
        self.self_attn = sent_enc_layer.self_attn

        self.dropout1 = sent_enc_layer.dropout1
        self.dropout2 = sent_enc_layer.dropout2
        self.dropout3 = sent_enc_layer.dropout3

        self.layer_norm_first = sent_enc_layer.layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = sent_enc_layer.self_attn_layer_norm
        self.fc1 = sent_enc_layer.fc1
        self.fc2 = sent_enc_layer.fc2

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = sent_enc_layer.final_layer_norm

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_weights=None,
        att_args=None,
    ):
        x, attn = super().forward(
            x, self_attn_mask, self_attn_padding_mask, need_weights, att_args
        )
        return x


# TODO retire SharedEncoder
class SharedEncoder(FairseqEncoder):
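    """Speech encoder that shares its top wav2vec layers with the mBART text encoder.

    The last ``shared_layers`` transformer layers are detached from the wav2vec
    encoder; the speech path runs them directly after the adaptor, while the
    text encoder gets them injected (wrapped in
    TransformerSentenceEncoderLayerStd) in place of its own top layers.
    """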
    def __init__(self, wav2vec_enc, mbart_enc, adaptor, shared_layers):
        super().__init__(None)
        self.w2v_encoder = wav2vec_enc
        self.shared_layers = self.w2v_encoder.w2v_model.encoder.layers[-shared_layers:]
        self.w2v_encoder.w2v_model.encoder.layers = (
            self.w2v_encoder.w2v_model.encoder.layers[:-shared_layers]
        )
        self.adaptor = adaptor
        if self.shared_layers[-1].layer_norm_first:
            self.final_layer_norm = mbart_enc.layer_norm
        else:
            mbart_enc.layer_norm = None
            self.final_layer_norm = None
        shared_layer_from = len(mbart_enc.layers) - shared_layers
        if shared_layer_from < 0:
            shared_layer_from = 0
        for layer_id, layer in enumerate(self.shared_layers):
            mbart_enc.layers[
                shared_layer_from + layer_id
            ] = TransformerSentenceEncoderLayerStd(layer)

    def forward(self, src_tokens, src_lengths=None, **kwargs):
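        """wav2vec encoder (minus the shared layers) -> adaptor -> shared layers
        -> optional final layer norm, returning the standard fairseq encoder dict."""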
        padding_mask = lengths_to_padding_mask(src_lengths)
        if not padding_mask.any():
            padding_mask = None

        out = self.w2v_encoder.forward(src_tokens, padding_mask, tbc=True)
        x = out["encoder_out"]
        enc_padding_mask = None
        if out["encoder_padding_mask"] is not None:
            enc_padding_mask = out["encoder_padding_mask"].transpose(
                0, 1
            )  # T X B --> B X T

        x, enc_padding_mask = self.adaptor(x, enc_padding_mask)
        for layer in self.shared_layers:
            x, _ = layer(x, enc_padding_mask)
        if self.final_layer_norm is not None:
            x = self.final_layer_norm(x)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [enc_padding_mask]
            if enc_padding_mask is not None
            else [],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": [],  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
        }


class StackedWav2VecEncoderWithAdaptor(FairseqEncoder):
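    """Speech encoder that stacks mBART encoder layers on top of wav2vec.

    Audio goes through the wav2vec encoder (optionally with its top
    ``drop_w2v_layers`` layers removed), then the length adaptor, and finally
    the mBART text encoder layers plus the mBART layer norm, if one is given.
    """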
    def __init__(
        self,
        wav2vec_enc,
        mbart_enc_layers,
        mbart_layer_norm,
        adaptor,
        drop_w2v_layers=0,
    ):
        super().__init__(None)
        self.w2v_encoder = wav2vec_enc
        self.adaptor = adaptor
        self.mbart_encoder_layers = mbart_enc_layers
        self.final_layer_norm = mbart_layer_norm
        if drop_w2v_layers > 0:
            self.w2v_encoder.w2v_model.encoder.layers = (
                self.w2v_encoder.w2v_model.encoder.layers[:-drop_w2v_layers]
            )

    def forward(self, src_tokens, src_lengths=None, return_all_hiddens=False, **kwargs):
        padding_mask = lengths_to_padding_mask(src_lengths)
        if not padding_mask.any():
            padding_mask = None

        out = self.w2v_encoder.forward(src_tokens, padding_mask, tbc=True)
        x = out["encoder_out"]
        enc_padding_mask = None
        if out["encoder_padding_mask"] is not None:
            enc_padding_mask = out["encoder_padding_mask"].transpose(
                0, 1
            )  # T X B --> B X T

        x, enc_padding_mask = self.adaptor(x, enc_padding_mask)
        encoder_states = []
        for layer in self.mbart_encoder_layers:
            x = layer(x, enc_padding_mask)
            if return_all_hiddens:
                encoder_states.append(x)
        if self.final_layer_norm is not None:
            x = self.final_layer_norm(x)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [enc_padding_mask]
            if enc_padding_mask is not None
            else [],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
        }

    def reorder_encoder_out(self, encoder_out, new_order):
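        """Reorder the batch dimension of the encoder output according to
        ``new_order``; used by the sequence generator during beam search."""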
        new_encoder_out = (
            []
            if len(encoder_out["encoder_out"]) == 0
            else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
        )

        new_encoder_padding_mask = (
            []
            if len(encoder_out["encoder_padding_mask"]) == 0
            else [
                x.index_select(0, new_order)
                for x in encoder_out["encoder_padding_mask"]
            ]
        )

        new_encoder_embedding = (
            []
            if len(encoder_out["encoder_embedding"]) == 0
            else [
                x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
            ]
        )
        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],  # B x T
            "src_lengths": [],  # B x 1
        }


# Note:
# dual input transformer:
#   encoder: wav2vec for speech + mbart encoder for text
#   decoder: mbart decoder for text
@register_model("dual_input_xm_transformer")
class DualInputXMTransformerModel(DualInputS2TTransformerModel):
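    """Dual-input speech/text model: a wav2vec (+ adaptor) speech encoder and an
    mBART text encoder feed a shared mBART decoder. The two inputs are combined
    by DualInputEncoder and TransformerMultiInputDecoder from
    s2t_dualinputtransformer.py.
    """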
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # wav2vec encoder
        Wav2VecEncoderWithAdaptor.add_args(parser)
        # add_decoder_args(parser)
        # mbart Transformer
        parser.add_argument(
            "--activation-fn",
            type=str,
            default="relu",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--mbart-dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--mbart-attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--mbart-activation-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN.",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-layers", type=int, metavar="N", help="num encoder layers"
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="N",
            help="num encoder attention heads",
        )
        parser.add_argument(
            "--encoder-normalize-before",
            action="store_true",
            help="apply layernorm before each encoder block",
        )
        parser.add_argument(
            "--decoder-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension",
        )
        parser.add_argument(
            "--decoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--decoder-layers", type=int, metavar="N", help="num decoder layers"
        )
        parser.add_argument(
            "--decoder-attention-heads",
            type=int,
            metavar="N",
            help="num decoder attention heads",
        )
        parser.add_argument(
            "--decoder-normalize-before",
            action="store_true",
            help="apply layernorm before each decoder block",
        )
        parser.add_argument(
            "--layernorm-embedding",
            action="store_true",
            help="add layernorm to embedding",
        )
        parser.add_argument(
            "--no-scale-embedding",
            action="store_true",
            help="if True, don't scale embeddings",
        )
        parser.add_argument(
            "--load-pretrained-mbart-from",
            type=str,
            metavar="STR",
            help="model to take text encoder decoder weights from (for initialization)",
        )
        # parser.add_argument("--finetune-w2v-params", type=str, metavar="STR",
        #                     help="comma-separated param strings to finetune.")
        parser.add_argument(
            "--finetune-mbart-decoder-params",
            type=str,
            metavar="STR",
            help="comma-separated param strings to finetune.",
        )
        parser.add_argument(
            "--finetune-mbart-encoder-params",
            type=str,
            metavar="STR",
            help="comma-separated param strings to finetune.",
        )
        parser.add_argument(
            "--skip-encoder-projection",
            action="store_true",
            help="skip the projection layer in encoder",
        )
        parser.add_argument(
            "--enc-grad-mult",
            type=float,
            metavar="V",
            default=1.0,
            help="multiply enc1 and enc2 gradient by V",
        )
        parser.add_argument(
            "--enc2-along-grad-mult",
            type=float,
            metavar="V",
            default=1.0,
            help="multiply enc2 gradient by V if only enc2 is used",
        )
        parser.add_argument(
            "--text-input-cost-ratio",
            type=float,
            default=1.0,
            metavar="V",
            help="text input cost ratio relative to speech input cost",
        )
        parser.add_argument(
            "--stack-w2v-mbart-encoder",
            action="store_true",
            help="stack w2v and mbart encoder",
        )
        parser.add_argument(
            "--stack-w2v-mbart-nonorm-encoder",
            action="store_true",
            help="stack w2v and mbart encoder without the final encoder layer norm",
        )
        parser.add_argument(
            "--no-final-norm-decoder",
            action="store_true",
            help="remove the final layer norm in the decoder",
        )
        parser.add_argument(
            "--drop-w2v-layers",
            type=int,
            default=0,
            metavar="N",
            help="drop w2v encoder layers",
        )
        parser.add_argument(
            "--share-w2v-text-encoder",
            action="store_true",
            help="share w2v encoder layers with text encoder",
        )
        parser.add_argument(
            "--shared-w2v-layers",
            type=int,
            default=0,
            metavar="N",
            help="shared encoder layers from w2v encoder",
        )

    @classmethod
    def build_encoder(cls, args, task):
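        """Build the mBART text encoder and wav2vec speech encoder, optionally
        load pretrained mBART weights, apply the stacking/sharing variant
        selected by the args, freeze all parameters not matched by the
        finetune-*-params options, and wrap both in a DualInputEncoder."""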
        _args = copy.deepcopy(args)
        _args.dropout = args.mbart_dropout
        _args.attention_dropout = args.mbart_attention_dropout
        _args.activation_dropout = args.mbart_activation_dropout
        _args.max_source_positions = 1024
        enc_emb = nn.Embedding(
            len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad()
        )
        text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb)
        spch_encoder = Wav2VecEncoderWithAdaptor(args)
        if getattr(args, "load_pretrained_mbart_from", None):
            text_encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=text_encoder, checkpoint=args.load_pretrained_mbart_from
            )
        if getattr(args, "stack_w2v_mbart_encoder", False):
            assert getattr(args, "share_w2v_text_encoder", False) is False
            spch_encoder = StackedWav2VecEncoderWithAdaptor(
                spch_encoder.w2v_encoder,
                text_encoder.layers,
                text_encoder.layer_norm,
                spch_encoder.adaptor,
                args.drop_w2v_layers,
            )
        elif getattr(args, "stack_w2v_mbart_nonorm_encoder", False):
            text_encoder.layer_norm = None
            spch_encoder = StackedWav2VecEncoderWithAdaptor(
                spch_encoder.w2v_encoder,
                text_encoder.layers,
                text_encoder.layer_norm,
                spch_encoder.adaptor,
                args.drop_w2v_layers,
            )
        elif getattr(args, "share_w2v_text_encoder", False):
            spch_encoder = SharedEncoder(
                spch_encoder.w2v_encoder,
                text_encoder,
                spch_encoder.adaptor,
                args.shared_w2v_layers,
            )

        for k, p in spch_encoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_w2v_params"
            ) and XMTransformerModel.finetune_params(args.finetune_w2v_params, k):
                p.requires_grad = True
            else:
                p.requires_grad = False
        for k, p in text_encoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_mbart_encoder_params"
            ) and XMTransformerModel.finetune_params(
                args.finetune_mbart_encoder_params, k
            ):
                p.requires_grad = True
            else:
                p.requires_grad = False
        cross_attentive_loss_before_last_layer = (
            0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
        )
        encoder = DualInputEncoder(
            args,
            spch_encoder,
            text_encoder,
            task.src_dict,
            cross_attentive_loss_before_last_layer,
        )
        return encoder

    @classmethod
    def build_decoder(cls, args, task):
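        """Build the mBART decoder, optionally load pretrained weights, freeze
        parameters not matched by --finetune-mbart-decoder-params, and wrap it
        in a TransformerMultiInputDecoder shared by the speech and text paths."""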
        _args = copy.deepcopy(args)
        _args.dropout = args.mbart_dropout
        _args.attention_dropout = args.mbart_attention_dropout
        _args.activation_dropout = args.mbart_activation_dropout
        _args.max_target_positions = 1024
        dec_emb = nn.Embedding(
            len(task.tgt_dict), _args.encoder_embed_dim, task.tgt_dict.pad()
        )
        decoder = TransformerDecoder(_args, task.tgt_dict, dec_emb)
        if getattr(args, "load_pretrained_mbart_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_mbart_from
            )
        if getattr(args, "no_final_norm_decoder", False):
            decoder.layer_norm = None
        for k, p in decoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_mbart_decoder_params"
            ) and XMTransformerModel.finetune_params(
                args.finetune_mbart_decoder_params, k
            ):
                p.requires_grad = True
            else:
                p.requires_grad = False

        compute_cross_attentive_loss = (
            True if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else False
        )
        cross_attentive_loss_without_norm = getattr(
            args, "attentive_cost_without_normalize", False
        )
        cross_attentive_loss_reverse = (
            False  # getattr(args, "attentive_cost_reverse", False)
        )
        decoder = TransformerMultiInputDecoder(
            dictionary=task.target_dictionary,
            spch_decoder=decoder,
            text_decoder=decoder,
            compute_cross_attentive_loss=compute_cross_attentive_loss,
            cross_attentive_loss_with_norm=True
            if not cross_attentive_loss_without_norm
            else False,
            cross_attentive_loss_reverse=cross_attentive_loss_reverse,
        )
        return decoder

    @classmethod
    def build_model(cls, args, task):
"""Build a new model instance.""" | |
# make sure that all args are properly defaulted | |
# (in case there are any new ones) | |
dualinputxmtransformer_base(args) | |
encoder = cls.build_encoder(args, task) | |
decoder = cls.build_decoder(args, task) | |
return cls(encoder, decoder) | |
def dualinputxmtransformer_base(args): | |
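    """Set default hyper-parameters for the base architecture: wav2vec encoder
    and adaptor defaults plus an mBART-large-style (12-layer, 1024-dim,
    16-head) text encoder/decoder."""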
    # wav2vec encoder
    set_default_w2v_encoder_args(args)
    set_default_adaptor_args(args)

    # mbart model
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(
        args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim
    )
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * 1024)
    args.decoder_layers = getattr(args, "decoder_layers", 12)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.mbart_attention_dropout = getattr(args, "mbart_attention_dropout", 0.0)
    args.mbart_activation_dropout = getattr(args, "mbart_activation_dropout", 0.0)
    args.mbart_dropout = getattr(args, "mbart_dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", True
    )
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", True)

    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
    args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)