# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from argparse import Namespace
from pathlib import Path

import torch
from fairseq.data import (
    encoders,
    Dictionary,
    ResamplingDataset,
    TransformEosLangPairDataset,
    ConcatDataset,
)
from fairseq.data.iterators import GroupedEpochBatchIterator
from fairseq.data.audio.multi_modality_dataset import (
    MultiModalityDataset,
    LangPairMaskDataset,
    ModalityDatasetItem,
)
from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
)
from fairseq.data.audio.speech_to_text_joint_dataset import (
    S2TJointDataConfig,
    SpeechToTextJointDatasetCreator,
)
from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from fairseq.tasks.translation import load_langpair_dataset

logger = logging.getLogger(__name__)
LANG_TAG_TEMPLATE = "<lang:{}>"


@register_task("speech_text_joint_to_text")
class SpeechTextJointToTextTask(SpeechToTextTask):
    """
    Task for joint training of speech and text to text.
    """

    @classmethod
    def add_args(cls, parser):
        """Add task-specific arguments to the parser."""
        super(SpeechTextJointToTextTask, cls).add_args(parser)
        ###
        parser.add_argument(
            "--parallel-text-data",
            default="",
            help="path to parallel text data directory",
        )
        parser.add_argument(
            "--max-tokens-text",
            type=int,
            metavar="N",
            help="maximum tokens for encoder text input",
        )
        parser.add_argument(
            "--max-positions-text",
            type=int,
            metavar="N",
            default=400,
            help="maximum tokens per encoder text input",
        )
        parser.add_argument(
            "--langpairs",
            default=None,
            metavar="S",
            help='language pairs for text training, separated with ","',
        )
        parser.add_argument(
            "--speech-sample-ratio",
            default=1,
            type=float,
            metavar="N",
            help="multiplicative sampling ratio for the speech dataset with transcripts",
        )
        parser.add_argument(
            "--text-sample-ratio",
            default=1,
            type=float,
            metavar="N",
            help="multiplicative sampling ratio for the text dataset",
        )
        parser.add_argument(
            "--update-mix-data",
            action="store_true",
            help="use mixed data in one update when update-freq > 1",
        )
        parser.add_argument(
            "--load-speech-only",
            action="store_true",
            help="load speech data only",
        )
        parser.add_argument(
            "--mask-text-ratio",
            type=float,
            metavar="V",
            default=0.0,
            help="mask V source tokens for text only mode",
        )
        parser.add_argument(
            "--mask-text-type",
            default="random",
            choices=["random", "tail"],
            help="mask text type",
        )
        parser.add_argument(
            "--noise-token",
            default="",
            help="noise token for masking src text tokens if mask-text-ratio > 0",
        )
        parser.add_argument(
            "--infer-target-lang",
            default="",
            metavar="S",
            help="target language for inference",
        )

    def __init__(self, args, src_dict, tgt_dict, infer_tgt_lang_id=None):
        super().__init__(args, tgt_dict)
        self.src_dict = src_dict
        self.data_cfg = S2TJointDataConfig(Path(args.data) / args.config_yaml)
        assert self.tgt_dict.pad() == self.src_dict.pad()
        assert self.tgt_dict.eos() == self.src_dict.eos()
        self.speech_only = args.load_speech_only
        self._infer_tgt_lang_id = infer_tgt_lang_id

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries)."""
        data_cfg = S2TJointDataConfig(Path(args.data) / args.config_yaml)
        tgt_dict_path = Path(args.data) / data_cfg.vocab_filename
        src_dict_path = Path(args.data) / data_cfg.src_vocab_filename
        if (not os.path.isfile(src_dict_path)) or (not os.path.isfile(tgt_dict_path)):
            raise FileNotFoundError("Dict not found: {}".format(args.data))
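        # Both vocabularies are plain fairseq Dictionary files whose names come
        # from the data config YAML (vocab_filename / src_vocab_filename).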
        src_dict = Dictionary.load(src_dict_path.as_posix())
        tgt_dict = Dictionary.load(tgt_dict_path.as_posix())
        print("| src dictionary: {} types".format(len(src_dict)))
        print("| tgt dictionary: {} types".format(len(tgt_dict)))
        if args.parallel_text_data != "":
            if not os.path.isabs(args.parallel_text_data):
                args.parallel_text_data = os.path.join(
                    args.data, args.parallel_text_data
                )
            if args.langpairs is None:
                raise Exception(
                    "Could not infer language pair, please provide it explicitly"
                )
        infer_tgt_lang_id = None
        if args.infer_target_lang != "" and data_cfg.prepend_tgt_lang_tag_no_change:
            tgt_lang_tag = SpeechToTextDataset.LANG_TAG_TEMPLATE.format(
                args.infer_target_lang
            )
            infer_tgt_lang_id = tgt_dict.index(tgt_lang_tag)
            assert infer_tgt_lang_id != tgt_dict.unk()
        return cls(args, src_dict, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id)

    def load_langpair_dataset(
        self, prepend_tgt_lang_tag=False, sampling_alpha=1.0, epoch=0
    ):
        lang_pairs = []
        text_dataset = None
        split = "train"
        for lp in self.args.langpairs.split(","):
            src, tgt = lp.split("-")
            text_dataset = load_langpair_dataset(
                self.args.parallel_text_data,
                split,
                src,
                self.src_dict,
                tgt,
                self.tgt_dict,
                combine=True,
                dataset_impl=None,
                upsample_primary=1,
                left_pad_source=False,
                left_pad_target=False,
                max_source_positions=self.args.max_positions_text,
                max_target_positions=self.args.max_target_positions,
                load_alignments=False,
                truncate_source=False,
            )
            if prepend_tgt_lang_tag:
                # TODO
                text_dataset = TransformEosLangPairDataset(
                    text_dataset,
                    src_eos=self.src_dict.eos(),
                    tgt_bos=self.tgt_dict.eos(),  # 'prev_output_tokens' starts with eos
                    new_tgt_bos=self.tgt_dict.index(LANG_TAG_TEMPLATE.format(tgt)),
                )
            lang_pairs.append(text_dataset)
        if len(lang_pairs) > 1:
            if sampling_alpha != 1.0:
                size_ratios = SpeechToTextDatasetCreator.get_size_ratios(
                    self.args.langpairs.split(","),
                    [len(s) for s in lang_pairs],
                    alpha=sampling_alpha,
                )
                lang_pairs = [
                    ResamplingDataset(
                        d, size_ratio=r, epoch=epoch, replace=(r >= 1.0)
                    )
                    for d, r in zip(lang_pairs, size_ratios)
                ]
            return ConcatDataset(lang_pairs)
        return text_dataset

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            return generator.generate(
                models,
                sample,
                prefix_tokens=prefix_tokens,
                constraints=constraints,
                bos_token=self._infer_tgt_lang_id,
            )

    def build_src_tokenizer(self, args):
        logger.info(f"src-pre-tokenizer: {self.data_cfg.src_pre_tokenizer}")
        return encoders.build_tokenizer(Namespace(**self.data_cfg.src_pre_tokenizer))

    def build_src_bpe(self, args):
        logger.info(f"tokenizer: {self.data_cfg.src_bpe_tokenizer}")
        return encoders.build_bpe(Namespace(**self.data_cfg.src_bpe_tokenizer))

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

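        When ``--parallel-text-data`` is set and this is a training split, the
        speech-to-text dataset and the parallel text dataset are wrapped
        together in a :class:`MultiModalityDataset`; otherwise only the speech
        dataset is loaded.
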
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        src_pre_tokenizer = self.build_src_tokenizer(self.args)
        src_bpe_tokenizer = self.build_src_bpe(self.args)
        ast_dataset = SpeechToTextJointDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.tgt_dict,
            src_dict=None if self.speech_only else self.src_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            src_pre_tokenizer=src_pre_tokenizer,
            src_bpe_tokenizer=src_bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
        )
        noise_token_id = -1
        text_dataset = None
        if self.args.parallel_text_data != "" and is_train_split:
            text_dataset = self.load_langpair_dataset(
                self.data_cfg.prepend_tgt_lang_tag_no_change,
                1.0,
                epoch=epoch,
            )
            if self.args.mask_text_ratio > 0:
                # add mask
                noise_token_id = (
                    self.src_dict.unk()
                    if self.args.noise_token == ""
                    else self.src_dict.index(self.args.noise_token)
                )
                text_dataset = LangPairMaskDataset(
                    text_dataset,
                    src_bos=self.src_dict.bos(),
                    src_eos=self.src_dict.eos(),
                    noise_id=noise_token_id,
                    mask_ratio=self.args.mask_text_ratio,
                    mask_type=self.args.mask_text_type,
                )

        if text_dataset is not None:
            mdsets = [
                ModalityDatasetItem(
                    "sup_speech",
                    ast_dataset,
                    (self.args.max_source_positions, self.args.max_target_positions),
                    self.args.max_tokens,
                    self.args.batch_size,
                ),
                ModalityDatasetItem(
                    "text",
                    text_dataset,
                    (self.args.max_positions_text, self.args.max_target_positions),
                    self.args.max_tokens_text
                    if self.args.max_tokens_text is not None
                    else self.args.max_tokens,
                    self.args.batch_size,
                ),
            ]
            ast_dataset = MultiModalityDataset(mdsets)
        self.datasets[split] = ast_dataset

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language model."""
        return self.tgt_dict

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        return None if self.speech_only else self.src_dict

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=0,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        if not isinstance(dataset, MultiModalityDataset):
            return super(SpeechTextJointToTextTask, self).get_batch_iterator(
                dataset,
                max_tokens,
                max_sentences,
                max_positions,
                ignore_invalid_inputs,
                required_batch_size_multiple,
                seed,
                num_shards,
                shard_id,
                num_workers,
                epoch,
                data_buffer_size,
                disable_iterator_cache,
            )

        mult_ratio = [self.args.speech_sample_ratio, self.args.text_sample_ratio]
        assert len(dataset.datasets) == 2

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        batch_samplers = dataset.get_batch_samplers(
            mult_ratio, required_batch_size_multiple, seed
        )

        # return a reusable, sharded iterator
        epoch_iter = GroupedEpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_samplers=batch_samplers,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
            mult_rate=1 if self.args.update_mix_data else max(self.args.update_freq),
            buffer_size=data_buffer_size,
        )
        self.dataset_to_epoch_iter[dataset] = {}  # refresh it every epoch
        return epoch_iter
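

# A minimal, hypothetical sketch of how this task might be invoked from the
# fairseq CLI. The data directory layout, file names, and model/criterion
# choices below are assumptions for illustration, not part of this module:
#
#   fairseq-train ${S2T_DATA_DIR} \
#     --task speech_text_joint_to_text \
#     --config-yaml config.yaml \
#     --parallel-text-data parallel_text \
#     --langpairs en-de \
#     --max-tokens 40000 \
#     --max-tokens-text 4000 \
#     --arch <model_arch> --criterion <criterion> --optimizer adam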