# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
import math | |
from collections.abc import Iterable | |
import torch | |
import torch.nn as nn | |
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask | |
from fairseq import utils | |
from fairseq.models import ( | |
FairseqEncoder, | |
FairseqEncoderDecoderModel, | |
FairseqEncoderModel, | |
FairseqIncrementalDecoder, | |
register_model, | |
register_model_architecture, | |
) | |
from fairseq.modules import ( | |
LinearizedConvolution, | |
TransformerDecoderLayer, | |
TransformerEncoderLayer, | |
VGGBlock, | |
) | |
@register_model("asr_vggtransformer")
class VGGTransformerModel(FairseqEncoderDecoderModel):
""" | |
Transformers with convolutional context for ASR | |
https://arxiv.org/abs/1904.11660 | |
""" | |
def __init__(self, encoder, decoder): | |
super().__init__(encoder, decoder) | |
    @staticmethod
    def add_args(parser):
"""Add model-specific arguments to the parser.""" | |
parser.add_argument( | |
"--input-feat-per-channel", | |
type=int, | |
metavar="N", | |
help="encoder input dimension per input channel", | |
) | |
parser.add_argument( | |
"--vggblock-enc-config", | |
type=str, | |
metavar="EXPR", | |
help=""" | |
an array of tuples each containing the configuration of one vggblock: | |
[(out_channels, | |
conv_kernel_size, | |
pooling_kernel_size, | |
num_conv_layers, | |
                use_layer_norm), ...]
""", | |
) | |
parser.add_argument( | |
"--transformer-enc-config", | |
type=str, | |
metavar="EXPR", | |
help="""" | |
a tuple containing the configuration of the encoder transformer layers | |
configurations: | |
[(input_dim, | |
num_heads, | |
ffn_dim, | |
normalize_before, | |
dropout, | |
attention_dropout, | |
                relu_dropout), ...]
""", | |
) | |
parser.add_argument( | |
"--enc-output-dim", | |
type=int, | |
metavar="N", | |
help=""" | |
            encoder output dimension, can be None. If specified, the
            transformer output is projected to this dimension""",
) | |
parser.add_argument( | |
"--in-channels", | |
type=int, | |
metavar="N", | |
help="number of encoder input channels", | |
) | |
parser.add_argument( | |
"--tgt-embed-dim", | |
type=int, | |
metavar="N", | |
help="embedding dimension of the decoder target tokens", | |
) | |
parser.add_argument( | |
"--transformer-dec-config", | |
type=str, | |
metavar="EXPR", | |
help=""" | |
            a tuple containing the configurations of the decoder transformer
            layers:
[(input_dim, | |
num_heads, | |
ffn_dim, | |
normalize_before, | |
dropout, | |
attention_dropout, | |
relu_dropout), ...] | |
""", | |
) | |
parser.add_argument( | |
"--conv-dec-config", | |
type=str, | |
metavar="EXPR", | |
help=""" | |
an array of tuples for the decoder 1-D convolution config | |
[(out_channels, conv_kernel_size, use_layer_norm), ...]""", | |
) | |
    @classmethod
    def build_encoder(cls, args, task):
return VGGTransformerEncoder( | |
input_feat_per_channel=args.input_feat_per_channel, | |
vggblock_config=eval(args.vggblock_enc_config), | |
transformer_config=eval(args.transformer_enc_config), | |
encoder_output_dim=args.enc_output_dim, | |
in_channels=args.in_channels, | |
) | |
    @classmethod
    def build_decoder(cls, args, task):
return TransformerDecoder( | |
dictionary=task.target_dictionary, | |
embed_dim=args.tgt_embed_dim, | |
transformer_config=eval(args.transformer_dec_config), | |
conv_config=eval(args.conv_dec_config), | |
encoder_output_dim=args.enc_output_dim, | |
) | |
    @classmethod
    def build_model(cls, args, task):
"""Build a new model instance.""" | |
# make sure that all args are properly defaulted | |
# (in case there are any new ones) | |
base_architecture(args) | |
encoder = cls.build_encoder(args, task) | |
decoder = cls.build_decoder(args, task) | |
return cls(encoder, decoder) | |
def get_normalized_probs(self, net_output, log_probs, sample=None): | |
        # net_output[0] is a (B, T, V) tensor from the decoder
lprobs = super().get_normalized_probs(net_output, log_probs, sample) | |
lprobs.batch_first = True | |
return lprobs | |
DEFAULT_ENC_VGGBLOCK_CONFIG = ((32, 3, 2, 2, False),) * 2 | |
DEFAULT_ENC_TRANSFORMER_CONFIG = ((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2 | |
# 256: embedding dimension | |
# 4: number of heads | |
# 1024: FFN | |
# True: apply LayerNorm before (dropout + residual) instead of after
# 0.2 (dropout): dropout after MultiheadAttention and second FC | |
# 0.2 (attention_dropout): dropout in MultiheadAttention | |
# 0.2 (relu_dropout): dropout after ReLU
DEFAULT_DEC_TRANSFORMER_CONFIG = ((256, 2, 1024, True, 0.2, 0.2, 0.2),) * 2 | |
DEFAULT_DEC_CONV_CONFIG = ((256, 3, True),) * 2 | |
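# Illustrative usage sketch (not part of the original training flow; shapes
# assume the defaults above, i.e. two VGG blocks with pooling kernel 2 and
# encoder_output_dim=512):
#
#   encoder = VGGTransformerEncoder(input_feat_per_channel=40)
#   feats = torch.randn(8, 100, 40)                    # (B, T, C * feat)
#   lengths = torch.full((8,), 100, dtype=torch.long)  # (B,)
#   out = encoder(feats, lengths)
#   # out["encoder_out"] has shape (ceil(100 / 4), 8, 512) = (25, 8, 512)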
# TODO: replace the transformer encoder config one-liner with explicit args
# to get rid of this transformation
def prepare_transformer_encoder_params( | |
input_dim, | |
num_heads, | |
ffn_dim, | |
normalize_before, | |
dropout, | |
attention_dropout, | |
relu_dropout, | |
): | |
args = argparse.Namespace() | |
args.encoder_embed_dim = input_dim | |
args.encoder_attention_heads = num_heads | |
args.attention_dropout = attention_dropout | |
args.dropout = dropout | |
args.activation_dropout = relu_dropout | |
args.encoder_normalize_before = normalize_before | |
args.encoder_ffn_embed_dim = ffn_dim | |
return args | |
def prepare_transformer_decoder_params( | |
input_dim, | |
num_heads, | |
ffn_dim, | |
normalize_before, | |
dropout, | |
attention_dropout, | |
relu_dropout, | |
): | |
args = argparse.Namespace() | |
args.encoder_embed_dim = None | |
args.decoder_embed_dim = input_dim | |
args.decoder_attention_heads = num_heads | |
args.attention_dropout = attention_dropout | |
args.dropout = dropout | |
args.activation_dropout = relu_dropout | |
args.decoder_normalize_before = normalize_before | |
args.decoder_ffn_embed_dim = ffn_dim | |
return args | |
class VGGTransformerEncoder(FairseqEncoder): | |
"""VGG + Transformer encoder""" | |
def __init__( | |
self, | |
input_feat_per_channel, | |
vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG, | |
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG, | |
encoder_output_dim=512, | |
in_channels=1, | |
transformer_context=None, | |
transformer_sampling=None, | |
): | |
"""constructor for VGGTransformerEncoder | |
Args: | |
- input_feat_per_channel: feature dim (not including stacked, | |
just base feature) | |
            - in_channels: number of input channels (e.g., if 8 feature
              vectors are stacked together, this is 8)
- vggblock_config: configuration of vggblock, see comments on | |
DEFAULT_ENC_VGGBLOCK_CONFIG | |
- transformer_config: configuration of transformer layer, see comments | |
on DEFAULT_ENC_TRANSFORMER_CONFIG | |
- encoder_output_dim: final transformer output embedding dimension | |
- transformer_context: (left, right) if set, self-attention will be focused | |
on (t-left, t+right) | |
            - transformer_sampling: an iterable of int whose length must match
              len(transformer_config); transformer_sampling[i] is the sampling
              factor applied after the i-th transformer layer (multihead
              attention and feedforward part)
""" | |
super().__init__(None) | |
self.num_vggblocks = 0 | |
if vggblock_config is not None: | |
if not isinstance(vggblock_config, Iterable): | |
raise ValueError("vggblock_config is not iterable") | |
self.num_vggblocks = len(vggblock_config) | |
self.conv_layers = nn.ModuleList() | |
self.in_channels = in_channels | |
self.input_dim = input_feat_per_channel | |
self.pooling_kernel_sizes = [] | |
if vggblock_config is not None: | |
for _, config in enumerate(vggblock_config): | |
( | |
out_channels, | |
conv_kernel_size, | |
pooling_kernel_size, | |
num_conv_layers, | |
layer_norm, | |
) = config | |
self.conv_layers.append( | |
VGGBlock( | |
in_channels, | |
out_channels, | |
conv_kernel_size, | |
pooling_kernel_size, | |
num_conv_layers, | |
input_dim=input_feat_per_channel, | |
layer_norm=layer_norm, | |
) | |
) | |
self.pooling_kernel_sizes.append(pooling_kernel_size) | |
in_channels = out_channels | |
input_feat_per_channel = self.conv_layers[-1].output_dim | |
transformer_input_dim = self.infer_conv_output_dim( | |
self.in_channels, self.input_dim | |
) | |
# transformer_input_dim is the output dimension of VGG part | |
self.validate_transformer_config(transformer_config) | |
self.transformer_context = self.parse_transformer_context(transformer_context) | |
self.transformer_sampling = self.parse_transformer_sampling( | |
transformer_sampling, len(transformer_config) | |
) | |
self.transformer_layers = nn.ModuleList() | |
if transformer_input_dim != transformer_config[0][0]: | |
self.transformer_layers.append( | |
Linear(transformer_input_dim, transformer_config[0][0]) | |
) | |
self.transformer_layers.append( | |
TransformerEncoderLayer( | |
prepare_transformer_encoder_params(*transformer_config[0]) | |
) | |
) | |
for i in range(1, len(transformer_config)): | |
if transformer_config[i - 1][0] != transformer_config[i][0]: | |
self.transformer_layers.append( | |
Linear(transformer_config[i - 1][0], transformer_config[i][0]) | |
) | |
self.transformer_layers.append( | |
TransformerEncoderLayer( | |
prepare_transformer_encoder_params(*transformer_config[i]) | |
) | |
) | |
self.encoder_output_dim = encoder_output_dim | |
self.transformer_layers.extend( | |
[ | |
Linear(transformer_config[-1][0], encoder_output_dim), | |
LayerNorm(encoder_output_dim), | |
] | |
) | |
def forward(self, src_tokens, src_lengths, **kwargs): | |
""" | |
src_tokens: padded tensor (B, T, C * feat) | |
src_lengths: tensor of original lengths of input utterances (B,) | |
""" | |
bsz, max_seq_len, _ = src_tokens.size() | |
x = src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim) | |
x = x.transpose(1, 2).contiguous() | |
# (B, C, T, feat) | |
for layer_idx in range(len(self.conv_layers)): | |
x = self.conv_layers[layer_idx](x) | |
bsz, _, output_seq_len, _ = x.size() | |
# (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> (T, B, C * feat) | |
x = x.transpose(1, 2).transpose(0, 1) | |
x = x.contiguous().view(output_seq_len, bsz, -1) | |
input_lengths = src_lengths.clone() | |
for s in self.pooling_kernel_sizes: | |
input_lengths = (input_lengths.float() / s).ceil().long() | |
encoder_padding_mask, _ = lengths_to_encoder_padding_mask( | |
input_lengths, batch_first=True | |
) | |
if not encoder_padding_mask.any(): | |
encoder_padding_mask = None | |
subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) | |
attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor) | |
transformer_layer_idx = 0 | |
for layer_idx in range(len(self.transformer_layers)): | |
if isinstance(self.transformer_layers[layer_idx], TransformerEncoderLayer): | |
x = self.transformer_layers[layer_idx]( | |
x, encoder_padding_mask, attn_mask | |
) | |
if self.transformer_sampling[transformer_layer_idx] != 1: | |
sampling_factor = self.transformer_sampling[transformer_layer_idx] | |
x, encoder_padding_mask, attn_mask = self.slice( | |
x, encoder_padding_mask, attn_mask, sampling_factor | |
) | |
transformer_layer_idx += 1 | |
else: | |
x = self.transformer_layers[layer_idx](x) | |
        # encoder_padding_mask is a (T x B) tensor; its [t, b] element
        # indicates whether encoder_output[t, b] is valid (valid=0, invalid=1)
return { | |
"encoder_out": x, # (T, B, C) | |
"encoder_padding_mask": encoder_padding_mask.t() | |
if encoder_padding_mask is not None | |
else None, | |
# (B, T) --> (T, B) | |
} | |
def infer_conv_output_dim(self, in_channels, input_dim): | |
sample_seq_len = 200 | |
sample_bsz = 10 | |
x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim) | |
for i, _ in enumerate(self.conv_layers): | |
x = self.conv_layers[i](x) | |
x = x.transpose(1, 2) | |
mb, seq = x.size()[:2] | |
return x.contiguous().view(mb, seq, -1).size(-1) | |
def validate_transformer_config(self, transformer_config): | |
for config in transformer_config: | |
input_dim, num_heads = config[:2] | |
if input_dim % num_heads != 0: | |
msg = ( | |
"ERROR in transformer config {}: ".format(config) | |
+ "input dimension {} ".format(input_dim) | |
+ "not dividable by number of heads {}".format(num_heads) | |
) | |
raise ValueError(msg) | |
def parse_transformer_context(self, transformer_context): | |
""" | |
transformer_context can be the following: | |
- None; indicates no context is used, i.e., | |
transformer can access full context | |
- a tuple/list of two int; indicates left and right context, | |
any number <0 indicates infinite context | |
* e.g., (5, 6) indicates that for query at x_t, transformer can | |
access [t-5, t+6] (inclusive) | |
* e.g., (-1, 6) indicates that for query at x_t, transformer can | |
access [0, t+6] (inclusive) | |
""" | |
if transformer_context is None: | |
return None | |
if not isinstance(transformer_context, Iterable): | |
raise ValueError("transformer context must be Iterable if it is not None") | |
if len(transformer_context) != 2: | |
raise ValueError("transformer context must have length 2") | |
left_context = transformer_context[0] | |
if left_context < 0: | |
left_context = None | |
right_context = transformer_context[1] | |
if right_context < 0: | |
right_context = None | |
if left_context is None and right_context is None: | |
return None | |
return (left_context, right_context) | |
def parse_transformer_sampling(self, transformer_sampling, num_layers): | |
""" | |
parsing transformer sampling configuration | |
Args: | |
- transformer_sampling, accepted input: | |
* None, indicating no sampling | |
* an Iterable with int (>0) as element | |
- num_layers, expected number of transformer layers, must match with | |
the length of transformer_sampling if it is not None | |
Returns: | |
- A tuple with length num_layers | |
""" | |
if transformer_sampling is None: | |
return (1,) * num_layers | |
if not isinstance(transformer_sampling, Iterable): | |
raise ValueError( | |
"transformer_sampling must be an iterable if it is not None" | |
) | |
if len(transformer_sampling) != num_layers: | |
raise ValueError( | |
"transformer_sampling {} does not match with the number " | |
"of layers {}".format(transformer_sampling, num_layers) | |
) | |
        for layer, value in enumerate(transformer_sampling):
            if not isinstance(value, int):
                raise ValueError(
                    "Invalid value in transformer_sampling: {}".format(value)
                )
            if value < 1:
                raise ValueError(
                    "Layer {} has an invalid sampling factor {}; it must be "
                    "a positive integer.".format(layer, value)
                )
return transformer_sampling | |
def slice(self, embedding, padding_mask, attn_mask, sampling_factor): | |
""" | |
embedding is a (T, B, D) tensor | |
padding_mask is a (B, T) tensor or None | |
attn_mask is a (T, T) tensor or None | |
""" | |
embedding = embedding[::sampling_factor, :, :] | |
if padding_mask is not None: | |
padding_mask = padding_mask[:, ::sampling_factor] | |
if attn_mask is not None: | |
attn_mask = attn_mask[::sampling_factor, ::sampling_factor] | |
return embedding, padding_mask, attn_mask | |
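    # Illustrative example: with sampling_factor=2, a (10, B, D) embedding
    # becomes (5, B, D), and the padding/attention masks are subsampled
    # consistently along their time dimensions.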
def lengths_to_attn_mask(self, input_lengths, subsampling_factor=1): | |
""" | |
create attention mask according to sequence lengths and transformer | |
context | |
Args: | |
- input_lengths: (B, )-shape Int/Long tensor; input_lengths[b] is | |
the length of b-th sequence | |
- subsampling_factor: int | |
            * Note that left_context and right_context are specified at the
              input frame level, while the input to the transformer may
              already have gone through subsampling (e.g., striding in the
              vggblocks); we use subsampling_factor to scale the left/right
              context
Return: | |
- a (T, T) binary tensor or None, where T is max(input_lengths) | |
* if self.transformer_context is None, None | |
* if left_context is None, | |
* attn_mask[t, t + right_context + 1:] = 1 | |
* others = 0 | |
* if right_context is None, | |
* attn_mask[t, 0:t - left_context] = 1 | |
* others = 0 | |
            * otherwise,
* attn_mask[t, t - left_context: t + right_context + 1] = 0 | |
* others = 1 | |
""" | |
if self.transformer_context is None: | |
return None | |
maxT = torch.max(input_lengths).item() | |
attn_mask = torch.zeros(maxT, maxT) | |
left_context = self.transformer_context[0] | |
right_context = self.transformer_context[1] | |
if left_context is not None: | |
left_context = math.ceil(self.transformer_context[0] / subsampling_factor) | |
if right_context is not None: | |
right_context = math.ceil(self.transformer_context[1] / subsampling_factor) | |
for t in range(maxT): | |
if left_context is not None: | |
st = 0 | |
en = max(st, t - left_context) | |
attn_mask[t, st:en] = 1 | |
if right_context is not None: | |
st = t + right_context + 1 | |
st = min(st, maxT - 1) | |
attn_mask[t, st:] = 1 | |
return attn_mask.to(input_lengths.device) | |
def reorder_encoder_out(self, encoder_out, new_order): | |
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( | |
1, new_order | |
) | |
if encoder_out["encoder_padding_mask"] is not None: | |
encoder_out["encoder_padding_mask"] = encoder_out[ | |
"encoder_padding_mask" | |
].index_select(1, new_order) | |
return encoder_out | |
class TransformerDecoder(FairseqIncrementalDecoder): | |
""" | |
Transformer decoder consisting of *args.decoder_layers* layers. Each layer | |
is a :class:`TransformerDecoderLayer`. | |
Args: | |
args (argparse.Namespace): parsed command-line arguments | |
dictionary (~fairseq.data.Dictionary): decoding dictionary | |
embed_tokens (torch.nn.Embedding): output embedding | |
no_encoder_attn (bool, optional): whether to attend to encoder outputs. | |
Default: ``False`` | |
left_pad (bool, optional): whether the input is left-padded. Default: | |
``False`` | |
""" | |
def __init__( | |
self, | |
dictionary, | |
embed_dim=512, | |
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG, | |
conv_config=DEFAULT_DEC_CONV_CONFIG, | |
encoder_output_dim=512, | |
): | |
super().__init__(dictionary) | |
vocab_size = len(dictionary) | |
self.padding_idx = dictionary.pad() | |
self.embed_tokens = Embedding(vocab_size, embed_dim, self.padding_idx) | |
self.conv_layers = nn.ModuleList() | |
for i in range(len(conv_config)): | |
out_channels, kernel_size, layer_norm = conv_config[i] | |
if i == 0: | |
conv_layer = LinearizedConv1d( | |
embed_dim, out_channels, kernel_size, padding=kernel_size - 1 | |
) | |
else: | |
conv_layer = LinearizedConv1d( | |
conv_config[i - 1][0], | |
out_channels, | |
kernel_size, | |
padding=kernel_size - 1, | |
) | |
self.conv_layers.append(conv_layer) | |
if layer_norm: | |
self.conv_layers.append(nn.LayerNorm(out_channels)) | |
self.conv_layers.append(nn.ReLU()) | |
self.layers = nn.ModuleList() | |
if conv_config[-1][0] != transformer_config[0][0]: | |
self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0])) | |
self.layers.append( | |
TransformerDecoderLayer( | |
prepare_transformer_decoder_params(*transformer_config[0]) | |
) | |
) | |
for i in range(1, len(transformer_config)): | |
if transformer_config[i - 1][0] != transformer_config[i][0]: | |
self.layers.append( | |
Linear(transformer_config[i - 1][0], transformer_config[i][0]) | |
) | |
self.layers.append( | |
TransformerDecoderLayer( | |
prepare_transformer_decoder_params(*transformer_config[i]) | |
) | |
) | |
self.fc_out = Linear(transformer_config[-1][0], vocab_size) | |
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): | |
""" | |
Args: | |
prev_output_tokens (LongTensor): previous decoder outputs of shape | |
`(batch, tgt_len)`, for input feeding/teacher forcing | |
encoder_out (Tensor, optional): output from the encoder, used for | |
encoder-side attention | |
incremental_state (dict): dictionary used for storing state during | |
:ref:`Incremental decoding` | |
Returns: | |
tuple: | |
- the last decoder layer's output of shape `(batch, tgt_len, | |
vocab)` | |
- the last decoder layer's attention weights of shape `(batch, | |
tgt_len, src_len)` | |
""" | |
target_padding_mask = ( | |
(prev_output_tokens == self.padding_idx).to(prev_output_tokens.device) | |
if incremental_state is None | |
else None | |
) | |
if incremental_state is not None: | |
prev_output_tokens = prev_output_tokens[:, -1:] | |
# embed tokens | |
x = self.embed_tokens(prev_output_tokens) | |
# B x T x C -> T x B x C | |
x = self._transpose_if_training(x, incremental_state) | |
for layer in self.conv_layers: | |
if isinstance(layer, LinearizedConvolution): | |
x = layer(x, incremental_state) | |
else: | |
x = layer(x) | |
# B x T x C -> T x B x C | |
x = self._transpose_if_inference(x, incremental_state) | |
# decoder layers | |
for layer in self.layers: | |
if isinstance(layer, TransformerDecoderLayer): | |
x, *_ = layer( | |
x, | |
(encoder_out["encoder_out"] if encoder_out is not None else None), | |
( | |
encoder_out["encoder_padding_mask"].t() | |
if encoder_out["encoder_padding_mask"] is not None | |
else None | |
), | |
incremental_state, | |
self_attn_mask=( | |
self.buffered_future_mask(x) | |
if incremental_state is None | |
else None | |
), | |
self_attn_padding_mask=( | |
target_padding_mask if incremental_state is None else None | |
), | |
) | |
else: | |
x = layer(x) | |
# T x B x C -> B x T x C | |
x = x.transpose(0, 1) | |
x = self.fc_out(x) | |
return x, None | |
def buffered_future_mask(self, tensor): | |
dim = tensor.size(0) | |
if ( | |
not hasattr(self, "_future_mask") | |
or self._future_mask is None | |
or self._future_mask.device != tensor.device | |
): | |
self._future_mask = torch.triu( | |
utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 | |
) | |
if self._future_mask.size(0) < dim: | |
self._future_mask = torch.triu( | |
utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1 | |
) | |
return self._future_mask[:dim, :dim] | |
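    # Illustrative example: for dim == 3 the buffered future mask is
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # so position t may only attend to positions <= t.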
def _transpose_if_training(self, x, incremental_state): | |
if incremental_state is None: | |
x = x.transpose(0, 1) | |
return x | |
def _transpose_if_inference(self, x, incremental_state): | |
if incremental_state: | |
x = x.transpose(0, 1) | |
return x | |
@register_model("asr_vggtransformer_encoder")
class VGGTransformerEncoderModel(FairseqEncoderModel):
def __init__(self, encoder): | |
super().__init__(encoder) | |
    @staticmethod
    def add_args(parser):
"""Add model-specific arguments to the parser.""" | |
parser.add_argument( | |
"--input-feat-per-channel", | |
type=int, | |
metavar="N", | |
help="encoder input dimension per input channel", | |
) | |
parser.add_argument( | |
"--vggblock-enc-config", | |
type=str, | |
metavar="EXPR", | |
help=""" | |
an array of tuples each containing the configuration of one vggblock | |
[(out_channels, conv_kernel_size, pooling_kernel_size,num_conv_layers), ...] | |
""", | |
) | |
parser.add_argument( | |
"--transformer-enc-config", | |
type=str, | |
metavar="EXPR", | |
help=""" | |
            a tuple containing the configurations of the Transformer layers:
[(input_dim, | |
num_heads, | |
ffn_dim, | |
normalize_before, | |
dropout, | |
attention_dropout, | |
relu_dropout), ]""", | |
) | |
parser.add_argument( | |
"--enc-output-dim", | |
type=int, | |
metavar="N", | |
help="encoder output dimension, projecting the LSTM output", | |
) | |
parser.add_argument( | |
"--in-channels", | |
type=int, | |
metavar="N", | |
help="number of encoder input channels", | |
) | |
parser.add_argument( | |
"--transformer-context", | |
type=str, | |
metavar="EXPR", | |
help=""" | |
either None or a tuple of two ints, indicating left/right context a | |
transformer can have access to""", | |
) | |
parser.add_argument( | |
"--transformer-sampling", | |
type=str, | |
metavar="EXPR", | |
help=""" | |
either None or a tuple of ints, indicating sampling factor in each layer""", | |
) | |
    @classmethod
    def build_model(cls, args, task):
"""Build a new model instance.""" | |
base_architecture_enconly(args) | |
encoder = VGGTransformerEncoderOnly( | |
vocab_size=len(task.target_dictionary), | |
input_feat_per_channel=args.input_feat_per_channel, | |
vggblock_config=eval(args.vggblock_enc_config), | |
transformer_config=eval(args.transformer_enc_config), | |
encoder_output_dim=args.enc_output_dim, | |
in_channels=args.in_channels, | |
transformer_context=eval(args.transformer_context), | |
transformer_sampling=eval(args.transformer_sampling), | |
) | |
return cls(encoder) | |
def get_normalized_probs(self, net_output, log_probs, sample=None): | |
# net_output['encoder_out'] is a (T, B, D) tensor | |
lprobs = super().get_normalized_probs(net_output, log_probs, sample) | |
        # lprobs is a (T, B, D) tensor; we need to transpose it to get a
        # (B, T, D) tensor
lprobs = lprobs.transpose(0, 1).contiguous() | |
lprobs.batch_first = True | |
return lprobs | |
class VGGTransformerEncoderOnly(VGGTransformerEncoder): | |
def __init__( | |
self, | |
vocab_size, | |
input_feat_per_channel, | |
vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG, | |
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG, | |
encoder_output_dim=512, | |
in_channels=1, | |
transformer_context=None, | |
transformer_sampling=None, | |
): | |
super().__init__( | |
input_feat_per_channel=input_feat_per_channel, | |
vggblock_config=vggblock_config, | |
transformer_config=transformer_config, | |
encoder_output_dim=encoder_output_dim, | |
in_channels=in_channels, | |
transformer_context=transformer_context, | |
transformer_sampling=transformer_sampling, | |
) | |
self.fc_out = Linear(self.encoder_output_dim, vocab_size) | |
def forward(self, src_tokens, src_lengths, **kwargs): | |
""" | |
src_tokens: padded tensor (B, T, C * feat) | |
src_lengths: tensor of original lengths of input utterances (B,) | |
""" | |
enc_out = super().forward(src_tokens, src_lengths) | |
x = self.fc_out(enc_out["encoder_out"]) | |
# x = F.log_softmax(x, dim=-1) | |
        # Note: this line is not needed because model.get_normalized_probs
        # will apply log_softmax
return { | |
"encoder_out": x, # (T, B, C) | |
"encoder_padding_mask": enc_out["encoder_padding_mask"], # (T, B) | |
} | |
def max_positions(self): | |
"""Maximum input length supported by the encoder.""" | |
        return (1e6, 1e6)  # an arbitrarily large number
def Embedding(num_embeddings, embedding_dim, padding_idx): | |
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) | |
# nn.init.uniform_(m.weight, -0.1, 0.1) | |
# nn.init.constant_(m.weight[padding_idx], 0) | |
return m | |
def Linear(in_features, out_features, bias=True, dropout=0): | |
"""Linear layer (input: N x T x C)""" | |
m = nn.Linear(in_features, out_features, bias=bias) | |
# m.weight.data.uniform_(-0.1, 0.1) | |
# if bias: | |
# m.bias.data.uniform_(-0.1, 0.1) | |
return m | |
def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs): | |
"""Weight-normalized Conv1d layer optimized for decoding""" | |
m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs) | |
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) | |
nn.init.normal_(m.weight, mean=0, std=std) | |
nn.init.constant_(m.bias, 0) | |
return nn.utils.weight_norm(m, dim=2) | |
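# Worked example of the init scale above (illustrative): with in_channels=256,
# kernel_size=3 and dropout=0, std = sqrt(4 / (3 * 256)) ~= 0.072.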
def LayerNorm(embedding_dim): | |
m = nn.LayerNorm(embedding_dim) | |
return m | |
# seq2seq models | |
def base_architecture(args): | |
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40) | |
args.vggblock_enc_config = getattr( | |
args, "vggblock_enc_config", DEFAULT_ENC_VGGBLOCK_CONFIG | |
) | |
args.transformer_enc_config = getattr( | |
args, "transformer_enc_config", DEFAULT_ENC_TRANSFORMER_CONFIG | |
) | |
args.enc_output_dim = getattr(args, "enc_output_dim", 512) | |
args.in_channels = getattr(args, "in_channels", 1) | |
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128) | |
args.transformer_dec_config = getattr( | |
args, "transformer_dec_config", DEFAULT_ENC_TRANSFORMER_CONFIG | |
) | |
args.conv_dec_config = getattr(args, "conv_dec_config", DEFAULT_DEC_CONV_CONFIG) | |
args.transformer_context = getattr(args, "transformer_context", "None") | |
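# Example CLI usage (sketch; the exact flags are assumptions and depend on the
# fairseq version and data preparation):
#   fairseq-train $DATA_DIR --task speech_recognition --arch vggtransformer_base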
@register_model_architecture("asr_vggtransformer", "vggtransformer_1")
def vggtransformer_1(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) | |
args.vggblock_enc_config = getattr( | |
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" | |
) | |
args.transformer_enc_config = getattr( | |
args, | |
"transformer_enc_config", | |
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14", | |
) | |
args.enc_output_dim = getattr(args, "enc_output_dim", 1024) | |
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128) | |
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4") | |
args.transformer_dec_config = getattr( | |
args, | |
"transformer_dec_config", | |
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4", | |
) | |
@register_model_architecture("asr_vggtransformer", "vggtransformer_2")
def vggtransformer_2(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) | |
args.vggblock_enc_config = getattr( | |
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" | |
) | |
args.transformer_enc_config = getattr( | |
args, | |
"transformer_enc_config", | |
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16", | |
) | |
args.enc_output_dim = getattr(args, "enc_output_dim", 1024) | |
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512) | |
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4") | |
args.transformer_dec_config = getattr( | |
args, | |
"transformer_dec_config", | |
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6", | |
) | |
@register_model_architecture("asr_vggtransformer", "vggtransformer_base")
def vggtransformer_base(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) | |
args.vggblock_enc_config = getattr( | |
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" | |
) | |
args.transformer_enc_config = getattr( | |
args, "transformer_enc_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12" | |
) | |
args.enc_output_dim = getattr(args, "enc_output_dim", 512) | |
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512) | |
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4") | |
args.transformer_dec_config = getattr( | |
args, "transformer_dec_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6" | |
) | |
# Size estimations: | |
# Encoder: | |
# - vggblock param: 64*1*3*3 + 64*64*3*3 + 128*64*3*3 + 128*128*3*3 = 258K
# Transformer: | |
# - input dimension adapter: 2560 x 512 -> 1.31M | |
# - transformer_layers (x12) --> 37.74M | |
# * MultiheadAttention: 512*512*3 (in_proj) + 512*512 (out_proj) = 1.048M | |
# * FFN weight: 512*2048*2 = 2.097M | |
# - output dimension adapter: 512 x 512 -> 0.26 M | |
# Decoder: | |
# - LinearizedConv1d: 512 * 256 * 3 + 256 * 256 * 3 * 3 | |
# - transformer_layer: (x6) --> 25.16M | |
# * MultiheadAttention (self-attention): 512*512*3 + 512*512 = 1.048M | |
# * MultiheadAttention (encoder-attention): 512*512*3 + 512*512 = 1.048M | |
# * FFN: 512*2048*2 = 2.097M | |
# Final FC: | |
# - FC: 512*5000 = 2.56M (assuming vocab size 5K)
# In total:
# ~68 M
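# A quick way to check these estimates against an instantiated model (sketch,
# assuming `model` is a built VGGTransformerModel):
#   n_params = sum(p.numel() for p in model.parameters())
#   print("{:.1f}M parameters".format(n_params / 1e6))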
# CTC models | |
def base_architecture_enconly(args): | |
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40) | |
args.vggblock_enc_config = getattr( | |
args, "vggblock_enc_config", "[(32, 3, 2, 2, True)] * 2" | |
) | |
args.transformer_enc_config = getattr( | |
args, "transformer_enc_config", "((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2" | |
) | |
args.enc_output_dim = getattr(args, "enc_output_dim", 512) | |
args.in_channels = getattr(args, "in_channels", 1) | |
args.transformer_context = getattr(args, "transformer_context", "None") | |
args.transformer_sampling = getattr(args, "transformer_sampling", "None") | |
@register_model_architecture("asr_vggtransformer_encoder", "vggtransformer_enc_1")
def vggtransformer_enc_1(args):
# vggtransformer_1 is the same as vggtransformer_enc_big, except the number | |
# of layers is increased to 16 | |
    # keep it here for backward compatibility purposes
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) | |
args.vggblock_enc_config = getattr( | |
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" | |
) | |
args.transformer_enc_config = getattr( | |
args, | |
"transformer_enc_config", | |
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16", | |
) | |
args.enc_output_dim = getattr(args, "enc_output_dim", 1024) | |