# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import copy
import logging
import math
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import II, MISSING, open_dict

from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import (
    BaseFairseqModel,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
)
from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES
from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer
from fairseq.tasks import FairseqTask

logger = logging.getLogger(__name__)
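
# This module defines two ways of fine-tuning a pretrained HuBERT encoder for ASR:
#   - HubertCtc: a linear projection to the target vocabulary on top of the HuBERT
#     transformer output, trained with CTC.
#   - HubertSeq2SeqModel: the HuBERT encoder paired with a standard TransformerDecoder.
# Both load the pretrained checkpoint given by `w2v_path` inside HubertEncoder.
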
@dataclass
class HubertAsrConfig(FairseqDataclass):
    w2v_path: str = field(default=MISSING, metadata={"help": "path to hubert model"})
    no_pretrained_weights: bool = field(
        default=False,
        metadata={"help": "if true, does not load pretrained weights"},
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    final_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout after transformer and before final projection"},
    )
    dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability inside hubert model"},
    )
    attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights inside hubert model"
        },
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN inside hubert model"
        },
    )
    encoder_embed_dim: Optional[int] = field(
        default=768, metadata={"help": "encoder embedding dimension"}
    )

    # masking
    apply_mask: bool = field(
        default=False, metadata={"help": "apply masking during fine-tuning"}
    )
    mask_length: int = field(
        default=10, metadata={"help": "length of each mask span"}
    )
    mask_prob: float = field(
        default=0.5,
        metadata={
            "help": "probability of replacing a token with mask (normalized by length)"
        },
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose masks"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument (used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"}
    )

    # channel masking
    mask_channel_length: int = field(
        default=10,
        metadata={"help": "length of the mask for features (channels)"},
    )
    mask_channel_prob: float = field(
        default=0.0,
        metadata={"help": "probability of replacing a feature with 0"},
    )
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument (used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_channel_overlap: bool = field(
        default=False,
        metadata={"help": "whether to allow channel masks to overlap"},
    )
    freeze_finetune_updates: int = field(
        default=0,
        metadata={"help": "don't finetune hubert for this many updates"},
    )
    feature_grad_mult: float = field(
        default=0.0,
        metadata={"help": "reset feature grad mult in hubert to this"},
    )
    layerdrop: float = field(
        default=0.0,
        metadata={"help": "probability of dropping a layer in hubert"},
    )
    normalize: bool = II("task.normalize")
    data: str = II("task.data")

    # this holds the loaded hubert args
    w2v_args: Any = None
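
# A minimal sketch of how these fields are typically set when fine-tuning with
# fairseq's hydra entry point (illustrative only; the checkpoint path, override
# values, and the surrounding task/criterion options depend on your setup):
#
#   fairseq-hydra-train \
#       model._name=hubert_ctc \
#       model.w2v_path=/path/to/hubert_base_ls960.pt \
#       model.apply_mask=true model.mask_prob=0.65 \
#       model.freeze_finetune_updates=10000 \
#       ...
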
@dataclass
class HubertCtcConfig(HubertAsrConfig):
    pass


@register_model("hubert_ctc", dataclass=HubertCtcConfig)
class HubertCtc(BaseFairseqModel):
    def __init__(self, cfg: HubertCtcConfig, w2v_encoder: BaseFairseqModel):
        super().__init__()
        self.cfg = cfg
        self.w2v_encoder = w2v_encoder

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: HubertCtcConfig, task: FairseqTask):
        """Build a new model instance."""
        w2v_encoder = HubertEncoder(cfg, task)
        return cls(cfg, w2v_encoder)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output["encoder_out"]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def get_logits(self, net_output):
        logits = net_output["encoder_out"]
        padding = net_output["encoder_padding_mask"]
        if padding is not None and padding.any():
            # B x T -> T x B so the mask lines up with the T x B x C logits.
            padding = padding.T
            # Assign through the boolean mask directly so the update is in place;
            # chained indexing like `logits[padding][..., 0] = 0` writes to a copy.
            fill = logits.new_full((logits.size(-1),), float("-inf"))
            fill[0] = 0  # keep only the blank label active at padded positions
            logits[padding] = fill
        return logits

    def forward(self, **kwargs):
        x = self.w2v_encoder(**kwargs)
        return x
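
# Sketch of the CTC data flow (shapes follow HubertEncoder.forward below; the
# waveform/padding tensors are placeholders):
#
#   net_output = model(source=waveform, padding_mask=pad)  # waveform: B x T_audio
#   net_output["encoder_out"]                              # T x B x V logits
#   lprobs = model.get_normalized_probs(net_output, log_probs=True)  # for CTC loss
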
@dataclass
class HubertSeq2SeqConfig(HubertAsrConfig):
    decoder_embed_dim: int = field(
        default=768, metadata={"help": "decoder embedding dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"})
    decoder_layerdrop: float = field(
        default=0.0, metadata={"help": "decoder layerdrop chance"}
    )
    decoder_attention_heads: int = field(
        default=4, metadata={"help": "num decoder attention heads"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    decoder_normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each decoder block"}
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings (outside self attention)"
        },
    )
    decoder_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability in the decoder"}
    )
    decoder_attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights inside the decoder"
        },
    )
    decoder_activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN inside the decoder"
        },
    )
    max_target_positions: int = field(
        default=2048, metadata={"help": "max target positions"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )
    autoregressive: bool = II("task.autoregressive")
    seq2seq_path: str = field(
        default="",
        metadata={"help": "path to a seq2seq (encoder-decoder) checkpoint to load"},
    )
    reset_dict: bool = field(
        default=False,
        metadata={
            "help": "if set, drop the decoder embedding weights from the loaded "
            "checkpoint (useful when the target dictionary changes)"
        },
    )
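
# Note on checkpoint loading (see build_model / load_state_dict below): when
# `seq2seq_path` is set, a full encoder-decoder checkpoint is loaded on top of the
# freshly built model, and with `reset_dict` the decoder token/output embeddings are
# dropped from that state dict first, so a different target dictionary can be used.
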
@register_model("hubert_seq2seq", dataclass=HubertSeq2SeqConfig)
class HubertSeq2SeqModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, cfg: HubertSeq2SeqConfig, task: FairseqTask):
        """Build a new model instance."""

        assert (
            cfg.autoregressive
        ), "Please set task.autoregressive=true for seq2seq asr models"

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            return emb

        decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)

        encoder = cls.build_encoder(cfg, task)
        decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)

        model = HubertSeq2SeqModel(encoder, decoder)

        if cfg["seq2seq_path"]:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.seq2seq_path)
            state = state["model"]
            if cfg["reset_dict"]:
                del state["decoder.embed_out"]
                del state["decoder.embed_tokens.weight"]
            model.load_state_dict(state, strict=False)
        return model

    @classmethod
    def build_encoder(cls, cfg: HubertAsrConfig, task):
        return HubertEncoder(cfg, task)

    @classmethod
    def build_decoder(cls, cfg: HubertSeq2SeqConfig, tgt_dict, embed_tokens):
        return TransformerDecoder(cfg, tgt_dict, embed_tokens)

    def forward(self, **kwargs):
        encoder_out = self.encoder(**kwargs)
        decoder_out = self.decoder(encoder_out=encoder_out, **kwargs)
        return decoder_out

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg=None,
        args: Optional[Namespace] = None,
    ):
        if model_cfg.reset_dict:
            logger.warning("Overriding strict state dict loading!")
            del state_dict["decoder.embed_out"]
            del state_dict["decoder.embed_tokens.weight"]
            return super().load_state_dict(state_dict, False, model_cfg, args)
        return super().load_state_dict(state_dict, strict, model_cfg, args)
class HubertEncoder(FairseqEncoder):
    def __init__(self, cfg: HubertAsrConfig, task):
        self.apply_mask = cfg.apply_mask

        # Fine-tuning hyper-parameters that override the values stored in the
        # pretrained checkpoint's config.
        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for "
            "both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        pretrain_task = tasks.setup_task(w2v_args.task)
        if state is not None and "task_state" in state:
            # This will load the stored "dictionaries" object
            pretrain_task.load_state_dict(state["task_state"])
        else:
            pretrain_task.load_state_dict(task.state_dict())

        model = pretrain_task.build_model(w2v_args.model, from_checkpoint=True)
        if state is not None and not cfg.no_pretrained_weights:
            # set strict=False because we omit some modules
            model.load_state_dict(state["model"], strict=False)

        model.remove_pretraining_modules()

        super().__init__(pretrain_task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        # CTC fine-tuning: project to the target vocabulary. Seq2seq fine-tuning:
        # only project if the decoder embedding dimension differs from the encoder's.
        if task.target_dictionary is not None and not cfg.autoregressive:
            self.proj = Linear(d, len(task.target_dictionary))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, cfg.decoder_embed_dim)
        else:
            self.proj = None

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        x = self.final_dropout(x)

        if self.proj:
            x = self.proj(x)

        return {
            "encoder_out": x,  # T x B x C
            "encoder_padding_mask": padding_mask,  # B x T
            "padding_mask": padding_mask,
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
                1, new_order
            )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        if encoder_out["padding_mask"] is not None:
            encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select(
                0, new_order
            )
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
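
# Minimal sketch of the freezing schedule used in HubertEncoder.forward above
# (illustrative only): the context manager is picked per step, so gradients start
# flowing into the pretrained encoder only once `num_updates` reaches
# `freeze_finetune_updates`.
#
#   ft = self.freeze_finetune_updates <= self.num_updates
#   ctx = contextlib.ExitStack() if ft else torch.no_grad()
#   with ctx:
#       x, padding_mask = self.w2v_model.extract_features(...)
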
class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *cfg.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        cfg (HubertSeq2SeqConfig): decoder configuration
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg: HubertSeq2SeqConfig,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
    ):
        super().__init__(dictionary)

        self.dropout = cfg.decoder_dropout
        self.share_input_output_embed = cfg.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = cfg.decoder_embed_dim
        self.output_embed_dim = cfg.decoder_embed_dim

        self.layerdrop = cfg.decoder_layerdrop

        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = cfg.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                cfg.max_target_positions,
                embed_dim,
                self.padding_idx,
                learned=cfg.decoder_learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )

        # TODO: update this when transformer gets converted to dataclass configs
        transformer_cfg = copy.deepcopy(cfg)
        with open_dict(transformer_cfg):
            transformer_cfg.dropout = transformer_cfg.decoder_dropout
            transformer_cfg.attention_dropout = (
                transformer_cfg.decoder_attention_dropout
            )
            transformer_cfg.activation_dropout = (
                transformer_cfg.decoder_activation_dropout
            )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerDecoderLayer(transformer_cfg, no_encoder_attn)
                for _ in range(transformer_cfg.decoder_layers)
            ]
        )

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim**-0.5)

        if transformer_cfg.decoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        if isinstance(prev_output_tokens, list):
            # Pad a list of variable-length token tensors into a single batch tensor.
            max_len = max(len(x) for x in prev_output_tokens)
            tmp = torch.zeros(
                [len(prev_output_tokens), max_len], device=prev_output_tokens[0].device
            )
            for i, p in enumerate(prev_output_tokens):
                tmp[i, : len(p)] = p
            prev_output_tokens = tmp
        prev_output_tokens = prev_output_tokens.long()
        x, extra = self.extract_features(
            prev_output_tokens, encoder_out, incremental_state
        )
        x = self.output_layer(x)
        return x, extra

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        # embed positions
        positions = (
            self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None

        inner_states = [x]

        # decoder layers
        self_attn_padding_mask = None
        if prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
        for layer in self.layers:
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, attn, _ = layer(
                    x,
                    encoder_out["encoder_out"] if encoder_out is not None else None,
                    encoder_out["padding_mask"] if encoder_out is not None else None,
                    incremental_state,
                    self_attn_mask=self.buffered_future_mask(x)
                    if incremental_state is None
                    else None,
                    self_attn_padding_mask=self_attn_padding_mask,
                )
                inner_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, {"attn": attn, "inner_states": inner_states}

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        # project back to size of vocabulary
        if self.share_input_output_embed:
            return F.linear(features, self.embed_tokens.weight)
        else:
            return F.linear(features, self.embed_out)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict

def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m
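
# Illustrative usage (assumes a fine-tuned CTC checkpoint on disk; the path and the
# waveform/padding tensors below are placeholders, not part of this module):
#
#   from fairseq import checkpoint_utils
#   models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
#       ["/path/to/finetuned_hubert_ctc.pt"]
#   )
#   model = models[0]  # an instance of HubertCtc
#   model.eval()
#   out = model(source=waveform_batch, padding_mask=padding_mask)
#   log_probs = model.get_normalized_probs(out, log_probs=True)  # T x B x V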