# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from pathlib import Path
from argparse import Namespace

from fairseq.data import Dictionary, encoders
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    get_features_or_waveform,
)
from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)


@register_task("speech_to_text")
class SpeechToTextTask(LegacyFairseqTask):
    """Task for training and inference with speech-to-text models, driven by
    TSV manifests and a YAML data config under the manifest root."""

    @classmethod
    def add_args(cls, parser):
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=6000,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )

    def __init__(self, args, tgt_dict):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
        self.speaker_to_id = self._get_speaker_to_id()

    def _get_speaker_to_id(self):
        # Map each speaker listed in the (optional) speaker set file to an
        # integer ID; returns None when no speaker set is configured.
        speaker_to_id = None
        speaker_set_filename = self.data_cfg.config.get("speaker_set_filename")
        if speaker_set_filename is not None:
            speaker_set_path = Path(self.args.data) / speaker_set_filename
            with open(speaker_set_path) as f:
                speaker_to_id = {r.strip(): i for i, r in enumerate(f)}
        return speaker_to_id

    @classmethod
    def setup_task(cls, args, **kwargs):
        data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
        dict_path = Path(args.data) / data_cfg.vocab_filename
        if not dict_path.is_file():
            raise FileNotFoundError(f"Dict not found: {dict_path.as_posix()}")
        tgt_dict = Dictionary.load(dict_path.as_posix())
        logger.info(
            f"dictionary size ({data_cfg.vocab_filename}): {len(tgt_dict):,}"
        )

        if getattr(args, "train_subset", None) is not None:
            if not all(s.startswith("train") for s in args.train_subset.split(",")):
                raise ValueError('Train splits should be named like "train*".')
        return cls(args, tgt_dict)

    def build_criterion(self, args):
        from fairseq import criterions

        if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1:
            raise ValueError(
                'Please set "--ignore-prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        return criterions.build_criterion(args, self)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = SpeechToTextDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
            speaker_to_id=self.speaker_to_id,
        )

    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def source_dictionary(self):
        # Source side is raw audio (features or waveform); there is no
        # source vocabulary.
        return None

    def max_positions(self):
        return self.args.max_source_positions, self.args.max_target_positions

    def build_model(self, args):
        args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
        args.input_channels = self.data_cfg.input_channels
        args.speaker_to_id = self.speaker_to_id
        return super().build_model(args)

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
            raise ValueError(
                'Please set "--prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        # Collect language tag token IDs so they are stripped from the
        # generator output rather than surfacing in decoded text.
        lang_token_ids = {
            i
            for s, i in self.tgt_dict.indices.items()
            if SpeechToTextDataset.is_lang_tag(s)
        }
        if extra_gen_cls_kwargs is None:
            extra_gen_cls_kwargs = {}
        extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids
        return super().build_generator(
            models,
            args,
            seq_gen_cls=seq_gen_cls,  # forward the caller-supplied generator class
            extra_gen_cls_kwargs=extra_gen_cls_kwargs,
        )

    def build_tokenizer(self, args):
        logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}")
        return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))

    def build_bpe(self, args):
        logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}")
        return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer))

    def get_interactive_tokens_and_lengths(self, lines, encode_fn):
        # In interactive mode, each input "line" is an audio path; use the
        # number of frames (or waveform samples) as the source length.
        n_frames = [get_features_or_waveform(p).shape[0] for p in lines]
        return lines, n_frames

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        return SpeechToTextDataset(
            "interactive", False, self.data_cfg, src_tokens, src_lengths
        )
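

# Example invocation (a minimal sketch, not part of the module): this task is
# typically selected via the fairseq CLI. The paths, subset names, and
# hyperparameter values below are placeholders; they depend on how the
# manifest TSVs and config.yaml were prepared (see examples/speech_to_text
# in the fairseq repository).
#
#   fairseq-train /path/to/manifest_root \
#     --task speech_to_text --config-yaml config.yaml \
#     --train-subset train --valid-subset dev \
#     --arch s2t_transformer_s --criterion label_smoothed_cross_entropy \
#     --label-smoothing 0.1 --max-tokens 40000 --save-dir /path/to/checkpoints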