# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import logging
import os.path as op
from argparse import Namespace
from collections import OrderedDict

import torch
from fairseq.data import (
    Dictionary,
    encoders,
    PrependTokenDataset,
    AppendTokenDataset,
    data_utils,
    StripTokenDataset,
    TokenBlockDataset,
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq import utils
from artst.data.multitask_dataset import MultitaskDataset
from artst.data.speech_to_text_dataset import SpeechToTextDataset
from artst.data.text_to_speech_dataset import TextToSpeechDataset
from artst.data.speech_to_speech_dataset import SpeechToSpeechDataset
from artst.data.speech_to_class_dataset import SpeechToClassDataset
from artst.data.speech_dataset import SpeechPretrainDataset
from artst.data.text_dataset import TextPretrainDataset
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.tasks.hubert_pretraining import LabelEncoder

logger = logging.getLogger(__name__)

TASK_NAME = ["s2t", "t2s", "s2s", "s2c", "pretrain"]


@register_task("artst")
class ArTSTTask(LegacyFairseqTask):
    @staticmethod
    def add_args(parser):
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--max-speech-sample-size",
            default=None,
            type=int,
            metavar="N",
            help="max speech sample size",
        )
        parser.add_argument(
            "--min-speech-sample-size",
            default=None,
            type=int,
            metavar="N",
            help="min speech sample size",
        )
        parser.add_argument(
            "--max-speech-positions",
            default=4000,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-text-positions",
            default=450,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )
        parser.add_argument(
            "--t5-task",
            choices=TASK_NAME,
            help="task for training",
        )
        parser.add_argument(
            "--bpe-tokenizer",
            type=str,
            default=None,
            help="bpe tokenizer for s2t",
        )
        # Speaker Identification (SID)
        parser.add_argument(
            "--finetune-from-modules",
            default=None,
            # choices=[
            #     "encoder-decoder", "encoder", "decoder",
            #     "speech_encoder_prenet-encoder-decoder-text_decoder_prenet-text_decoder_postnet",  # ASR, T5 SID
            #     "speech_encoder_prenet-encoder-decoder-text_decoder_prenet-speaker_decoder_postnet",  # SID
            #     "speech_encoder_prenet-encoder-decoder-speech_decoder_prenet-speech_decoder_postnet",  # VC, SE
            #     "text_encoder_prenet-encoder-decoder-speech_decoder_prenet-speech_decoder_postnet",  # TTS
            # ],
            help="If set, load only the listed modules from the fine-tuning checkpoint.",
        )
        parser.add_argument(
            "--finetune-out-of-modules",
            default=None,
            # choices=[
            #     "speaker_decoder_postnet",  # SID
            #     "speech_decoder_postnet",  # SE with reduction factor 1
            # ],
            help="If set, exclude the listed modules when loading the fine-tuning checkpoint.",
        )
        # BART
        parser.add_argument(
            "--shorten-method",
            default="none",
            choices=["none", "truncate", "random_crop"],
            help="if not none, shorten sequences that exceed --tokens-per-sample",
        )
        parser.add_argument(
            "--shorten-data-split-list",
            default="",
            help="comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)',
        )
        parser.add_argument(
            "--tokens-per-sample",
            default=512,
            type=int,
            help="max number of total tokens over all segments per sample for dataset",
        )
        parser.add_argument(
            "--sample-break-mode",
            default="eos",
            type=str,
            help="mode for breaking sentence",
        )
        parser.add_argument(
            "--mask",
            default=0.3,
            type=float,
            help="fraction of words/subwords that will be masked",
        )
        parser.add_argument(
            "--mask-random",
            default=0.1,
            type=float,
            help="instead of using [MASK], use random token this often",
        )
        parser.add_argument(
            "--insert",
            default=0.0,
            type=float,
            help="insert this percentage of additional random tokens",
        )
        parser.add_argument(
            "--permute",
            default=0.0,
            type=float,
            help="take this proportion of subwords and permute them",
        )
        parser.add_argument(
            "--rotate",
            default=0.0,
            type=float,
            help="rotate this proportion of inputs",
        )
        parser.add_argument(
            "--poisson-lambda",
            default=3.5,
            type=float,
            help="lambda of the Poisson distribution used to sample span lengths for span masking",
        )
        parser.add_argument(
            "--permute-sentences",
            default=0.0,
            type=float,
            help="shuffle this proportion of sentences in all inputs",
        )
        # NOTE: --mask-length is read in load_dataset (for mask_whole_words),
        # so it must stay registered rather than commented out.
        parser.add_argument(
            "--mask-length",
            default="span-poisson",
            type=str,
            choices=["subword", "word", "span-poisson"],
            help="mask length to choose",
        )
        parser.add_argument(
            "--replace-length",
            default=1,
            type=int,
            help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
        )
        parser.add_argument(
            "--iid-noise-target",
            action="store_true",
            help="whether to use T5-style targets",
        )
        # Hubert
        parser.add_argument(
            "--hubert-labels",
            nargs="*",
            type=str,
            default=["km"],
            help="extension of the label files to load: frame-level labels for "
            "pre-training, sequence-level labels for fine-tuning",
        )
        parser.add_argument(
            "--hubert-label-dir",
            type=str,
            default=None,
            help="if set, looks for labels in this directory instead",
        )
        parser.add_argument(
            "--sample-rate",
            default=100,
            type=float,
            help="target sample rate. "
            "audio files will be up/down sampled to this rate",
        )
        parser.add_argument(
            "--label-rates",
            default=-1,
            type=float,
            help="label frame rate; -1 for sequence-level labels",
        )
        parser.add_argument(
            "--normalize",
            action="store_true",
            help="if set, normalizes input to have 0 mean and unit variance",
        )
        parser.add_argument(
            "--enable-padding",
            action="store_true",
            help="pad shorter samples instead of cropping",
        )
        parser.add_argument(
            "--pad-audio",
            action="store_true",
            help="pad audio to the longest one in the batch if true",
        )
        parser.add_argument(
            "--random-crop",
            action="store_true",
            help="always crop from the beginning if false",
        )
        parser.add_argument(
            "--single-target",
            action="store_true",
            help="if set, AddTargetDatasets outputs same keys as AddTargetDataset",
        )
        parser.add_argument(
            "--batch-ratio",
            default=None,
            type=str,
            help="ratio of batch size for each dataset",
        )
        parser.add_argument(
            "--sample-ratios",
            default=None,
            type=str,
            help="ratio of sample for each dataset",
        )
        parser.add_argument(
            "--ctc-weight",
            type=float,
            default=0.0,
            help="ctc weight for inference",
        )
        # NOTE: with type=bool, any non-empty string (even "False") parses as
        # True; omit the flag to keep the default.
        parser.add_argument(
            "--inference-speech",
            type=bool,
            default=False,
            help="inference for TTS",
        )

    def __init__(self, args, dicts, config):
        super().__init__(args)
        self.dicts = dicts
        self.config = config
        self.t5_task = args.t5_task
        # Used for filtering by size
        if self.t5_task in ["s2t", "t2s", "s2s", "s2c"]:
            self.max_pos = [self.args.max_speech_positions * 256]
        elif self.t5_task == "pretrain":
            self.max_pos = [
                self.args.max_speech_positions * 256,
                self.args.max_text_positions,
            ]

        self.mask_idx = self.dicts["text"].add_symbol("<mask>")
        # add blank token for ctc
        # if args.ctc_weight > 0:
        self.blank_symbol_idx = self.dicts["text"].add_symbol("<ctc_blank>")
        self.blank_symbol = "<ctc_blank>"

        # add mask tokens for iid noise targets
        if hasattr(args, "iid_noise_target") and args.iid_noise_target:
            self.uni_mask_idxs = []
            for i in range(600):
                self.uni_mask_idxs.append(
                    self.dicts["text"].add_symbol("<mask>" + str(i))
                )
            self.uni_mask_idxs = torch.tensor(self.uni_mask_idxs)

        self.seed = args.seed

    @classmethod
    def setup_task(cls, args, **kwargs):
        # load dictionaries and config
        dicts = OrderedDict()
        if args.t5_task == "pretrain" and not hasattr(args, "shuffle_instance"):
            args.shuffle_instance = False

        # Prepare config
        config = None
        logger.info("No config file for " + args.t5_task)

        if args.t5_task == "pretrain":
            dicts["hubert"] = [
                Dictionary.load(f"{args.hubert_label_dir}/dict.{label}.txt")
                for label in args.hubert_labels
            ]
            dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
        else:
            if config is None:
                dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
            else:
                dicts["text"] = Dictionary.load(
                    op.join(args.data, config.vocab_filename)
                )

        return cls(args, dicts, config)

    def build_criterion(self, args):
        from fairseq import criterions

        return criterions.build_criterion(args, self)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        sample_ratios = []
        if self.t5_task == "s2t":
            ## For speech-to-text task
            bpe_tokenizer = self.build_bpe(self.args)
            manifest = f"{self.args.data}/{split}.tsv"
            procs = [LabelEncoder(self.dicts["text"])]
            paths = [f"{self.args.hubert_label_dir}/{split}.txt"]
            # Hawau: view dataset...
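            # Expected on-disk layout, inferred from the paths built above
            # (HuBERT-style manifests; the exact field contents are an
            # assumption, not guaranteed by this file):
            #   {data}/{split}.tsv              audio manifest: first line is the
            #                                   audio root, then "relpath\tnum_samples"
            #   {hubert_label_dir}/{split}.txt  one target transcript per manifest
            #                                   entry, in the same order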
logger.info(f"Manifest: {manifest}") # logger.info(f"Paths: {paths}") self.datasets[split] = SpeechToTextDataset( manifest, sample_rate=self.args.sample_rate, label_paths=paths, label_processors=procs, max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size, min_keep_sample_size=self.args.min_speech_sample_size, normalize=self.args.normalize, store_labels=False, tgt_dict=self.dicts["text"], tokenizer=bpe_tokenizer, ) elif self.t5_task == "t2s": ## For text to speech task from fairseq.data import ConcatDataset bpe_tokenizer = self.build_bpe(self.args) procs = [LabelEncoder(self.dicts["text"])] t2s_datasets = [ TextToSpeechDataset( manifest_path=f"{self.args.data}/{name}.tsv", sample_rate=self.args.sample_rate, label_paths=[f"{self.args.hubert_label_dir}/{name}.txt"], label_processors=procs, max_keep_sample_size=self.max_pos[0], normalize=self.args.normalize, store_labels=False, src_dict=self.dicts["text"], tokenizer=bpe_tokenizer, reduction_factor=self.args.reduction_factor, inference=self.args.inference_speech, ) for name in split.split(",") ] self.datasets[split] = ConcatDataset(t2s_datasets) if len(t2s_datasets) > 1 else t2s_datasets[0] elif self.t5_task == "s2s": manifest = f"{self.args.data}/{split}.tsv" self.datasets[split] = SpeechToSpeechDataset( manifest_path=manifest, sample_rate=self.args.sample_rate, max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size, min_keep_sample_size=self.args.min_speech_sample_size, normalize=self.args.normalize, reduction_factor=self.args.reduction_factor, ) elif self.t5_task == "s2c": is_train_split = ("train" in split) is_valid_split = ("valid" in split) if is_train_split: max_length = 51200 elif is_valid_split: max_length = 76800 else: max_length = 2560000 manifest = op.join(f"{self.args.data}", f"{split}.tsv") procs = LabelEncoder(self.dicts["text"]) # map speaker to id self.datasets[split] = SpeechToClassDataset( manifest_path=manifest, sample_rate=self.args.sample_rate, label_processors=procs, max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size, min_keep_sample_size=self.args.min_speech_sample_size, normalize=self.args.normalize, tgt_dict=self.dicts["text"], max_length=max_length ) elif self.t5_task == "pretrain": is_train_split = ("train" in split) pretrain_datasets = [] speech_split, text_split = split.split('|') ## Speech pre-train manifest = f"{self.args.data}/{speech_split}.tsv" dicts = self.dicts["hubert"] pad_list = [dict.pad() for dict in dicts] eos_list = [dict.eos() for dict in dicts] procs = [LabelEncoder(dict) for dict in dicts] paths = [ f"{self.args.hubert_label_dir}/{speech_split}.{l}" for l in self.args.hubert_labels ] # hubert v1: pad_audio=True, random_crop=False; self.args.dec_weight = getattr(self.args, "dec_weight", 1.0) pretrain_datasets.append( SpeechPretrainDataset( manifest, sample_rate=self.args.sample_rate, label_paths=paths, label_rates=self.args.label_rates, pad_list=pad_list, eos_list=eos_list, label_processors=procs, max_keep_sample_size=None, min_keep_sample_size=32000, max_sample_size=self.args.max_speech_sample_size, pad_audio=self.args.pad_audio, normalize=self.args.normalize, store_labels=False, random_crop=self.args.random_crop, single_target=self.args.single_target, reduction_factor=self.args.reduction_factor, ) ) sample_ratios.append(sum([pretrain_datasets[0].size(i) for i in range(len(pretrain_datasets[0]))])) ## 

            ## Text pre-train
            paths = utils.split_paths(self.args.data)
            assert len(paths) > 0
            data_path = paths[(epoch - 1) % len(paths)]
            logger.info(f"Loading {text_split} from data_path={data_path}")
            split_path = op.join(data_path, text_split)
            logger.info(f"split_path={split_path}")
            bart_dataset = data_utils.load_indexed_dataset(
                split_path,
                self.dicts["text"],
                self.args.dataset_impl,
                combine=combine,
            )
            if bart_dataset is None:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(text_split, split_path)
                )
            bart_dataset = StripTokenDataset(bart_dataset, self.dicts["text"].eos())
            bart_dataset = maybe_shorten_dataset(
                bart_dataset,
                text_split,
                self.args.shorten_data_split_list,
                self.args.shorten_method,
                self.args.tokens_per_sample,
                self.args.seed,
            )
            # create continuous blocks of tokens
            bart_dataset = TokenBlockDataset(
                bart_dataset,
                bart_dataset.sizes,
                self.args.tokens_per_sample - 2,  # one less for <s> and one for </s>
                pad=self.dicts["text"].pad(),
                eos=self.dicts["text"].eos(),
                break_mode=self.args.sample_break_mode,
                document_sep_len=0,
            )
            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            bart_dataset = PrependTokenDataset(bart_dataset, self.dicts["text"].bos())
            bart_dataset = AppendTokenDataset(bart_dataset, self.dicts["text"].eos())
            mask_whole_words = (
                get_whole_word_mask(self.args, self.dicts["text"])
                if self.args.mask_length != "subword"
                else None
            )
            self.args.bert_weight = getattr(self.args, "bert_weight", 0.0)
            pretrain_datasets.append(
                TextPretrainDataset(
                    bart_dataset,
                    bart_dataset.sizes,
                    self.dicts["text"],
                    self.mask_idx,
                    mask_whole_words,
                    shuffle=self.args.shuffle_instance,
                    seed=self.seed,
                    args=self.args,
                    iid_noise_target=self.args.iid_noise_target,
                    uni_mask_idxs=self.uni_mask_idxs
                    if self.args.iid_noise_target
                    else None,
                )
            )
            sample_ratios.append(sum(pretrain_datasets[1].sizes))
            logger.info(
                "Task: {0}, Loaded {1} samples of denoising_dataset".format(
                    "bart",
                    len(pretrain_datasets[1]),
                )
            )

            logger.info("token ratio is " + str(sample_ratios))
            if self.args.batch_ratio is not None:
                # batch_ratio is given as a Python literal, e.g. "[1, 0.5]"
                batch_ratio = eval(self.args.batch_ratio)
                assert len(batch_ratio) == len(sample_ratios)
                sample_ratios = [
                    sample_ratios[i] / batch_ratio[i]
                    for i in range(len(sample_ratios))
                ]
            else:
                batch_ratio = None
                max_size = max(sample_ratios)
                sample_ratios = [max_size / r for r in sample_ratios]
            if hasattr(self.args, "sample_ratios") and self.args.sample_ratios is not None:
                sample_ratios = eval(self.args.sample_ratios)
            if is_train_split:
                self.datasets[split] = MultitaskDataset(
                    pretrain_datasets, sample_ratios, batch_ratio
                )
            else:
                self.datasets[split] = MultitaskDataset(
                    pretrain_datasets, batch_ratio=batch_ratio
                )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        model.set_num_updates(update_num)
        # Junyi: do not use sample_size; normalize the loss locally instead
        agg_loss, agg_sample_size, agg_logging_output = 0.0, 1.0, {}
        agg_logging_output["sample_size"] = 1

        def forward_backward(model, samples, weight=1.0):
            nonlocal agg_loss, agg_logging_output
            if samples is None or len(samples) == 0:
                return
            loss, sample_size, logging_output = criterion(model, samples)
            if ignore_grad:
                loss *= 0
            else:
                loss *= weight
            loss = loss / sample_size
            optimizer.backward(loss)
            agg_loss += loss.detach().item()
            # TODO make summing of the sample sizes configurable
            for k in logging_output:
                if k == "ntokens" or k == "nsentences":
                    if k not in agg_logging_output:
                        agg_logging_output[k] = 0
                    agg_logging_output[k] += logging_output[k]
                    # continue
                # agg_logging_output[k] += logging_output[k]
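                # The full logging_output dict is stored below under the
                # batch's task name, so losses can be reported per task when
                # MultitaskDataset interleaves speech and text batches.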
                # agg_logging_output[task_name] += logging_output[k]
            agg_logging_output[samples["task_name"]] = logging_output

        forward_backward(model, sample)
        agg_logging_output["loss"] = agg_loss
        return agg_loss, agg_sample_size, agg_logging_output

    def valid_step(self, sample, model, criterion):
        model.eval()
        with torch.no_grad():
            from collections import defaultdict

            agg_loss, agg_sample_size, agg_logging_output = 0.0, 1.0, defaultdict(float)
            agg_logging_output["sample_size"] = 1
            loss, sample_size, logging_output = criterion(model, sample)
            loss = loss / sample_size
            # agg_loss += loss.data.item() if isinstance(loss, torch.Tensor) else loss
            agg_loss += loss.item() if isinstance(loss, torch.Tensor) else loss
            agg_logging_output[sample["task_name"]] = logging_output
        agg_logging_output["loss"] = agg_loss
        return agg_loss, agg_sample_size, agg_logging_output

    @property
    def target_dictionary(self):
        return self.dicts["text"]

    @property
    def source_dictionary(self):
        return None

    def build_model(self, args):
        try:
            args.input_feat_per_channel = self.config.input_feat_per_channel
            args.input_channels = self.config.input_channels
        except Exception as e:
            args.input_feat_per_channel = 80
            args.input_channels = 1
            logger.info("Cannot read input_feat_per_channel/input_channels from config:")
            logger.warning(e)
            logger.info(
                f"Set to: {args.input_feat_per_channel} and {args.input_channels}"
            )

        args.speech_odim = args.input_feat_per_channel * args.input_channels

        args.label_rates = self.args.label_rates
        args.sample_rate = self.args.sample_rate
        self.args.reduction_factor = args.reduction_factor
        return super(ArTSTTask, self).build_model(args)

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        from artst.sequence_generator import SequenceGenerator

        # guard against None so the kwargs merge cannot raise a TypeError
        extra_gen_cls_kwargs = {
            "ctc_weight": self.args.ctc_weight,
            **(extra_gen_cls_kwargs or {}),
        }
        return super().build_generator(
            models,
            args,
            seq_gen_cls=SequenceGenerator,
            extra_gen_cls_kwargs=extra_gen_cls_kwargs,
        )

    def build_tokenizer(self, args):
        if self.config is None:
            logger.info("pre-tokenizer: None")
            return encoders.build_tokenizer(Namespace(**{"tokenizer": None}))
        else:
            logger.info(f"pre-tokenizer: {self.config.pre_tokenizer}")
            return encoders.build_tokenizer(Namespace(**self.config.pre_tokenizer))

    def build_bpe(self, args):
        if self.config is not None:
            logger.info(f"tokenizer: {self.config.bpe_tokenizer}")
            return encoders.build_bpe(Namespace(**self.config.bpe_tokenizer))
        else:
            logger.info(f"tokenizer: {self.args.bpe_tokenizer}")
            return encoders.build_bpe(
                Namespace(
                    **{
                        "bpe": "sentencepiece",
                        "sentencepiece_model": self.args.bpe_tokenizer,
                    }
                )
            )

    def generate_class(self, models, net_input, prefix_tokens, **kwargs):
        with torch.no_grad():
            encoder_input = {
                k: v
                for k, v in net_input.items()
                if k != "prev_output_tokens" and k != "task_name"
            }
            encoder_input.update(kwargs)
            encoder_input.update({"prev_output_tokens": prefix_tokens})
            return models[0].generate_class(**encoder_input)

    def generate_speech(self, models, net_input, **kwargs):
        with torch.no_grad():
            encoder_input = {
                k: v
                for k, v in net_input.items()
                if k != "prev_output_tokens" and k != "task_name"
            }
            encoder_input.update(kwargs)
            return models[0].generate_speech(**encoder_input)

    def inference_t2s(self, models, sample):
        with torch.no_grad():
            xs = sample["net_input"]["src_tokens"]
            spkemb = sample["net_input"]["spkembs"]
            return models[0].inference(xs, spkemb)

    def inference_s2s(self, models, sample, force_equal_length=False):
        with torch.no_grad():
            x = sample["net_input"]["src_tokens"]
            xlen = sample["net_input"]["src_lengths"]
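            # The keys read from net_input here (src_tokens, src_lengths,
            # spkembs, prev_output_tokens, padding_mask, tgt_lengths) are
            # assumed to be produced by SpeechToSpeechDataset's collater.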
            spkemb = sample["net_input"]["spkembs"]
            prev_output_tokens = sample["net_input"]["prev_output_tokens"]
            padding_mask = sample["net_input"]["padding_mask"]
            tgt_lengths = sample["net_input"]["tgt_lengths"]
            return models[0].inference_s2s(
                x,
                xlen,
                spkemb,
                prev_output_tokens,
                tgt_lengths,
                force_equal_length=force_equal_length,
                padding_mask=padding_mask,
            )

    def inference_s2c(self, models, sample):
        with torch.no_grad():
            x = sample["net_input"]["src_tokens"]
            xlen = sample["net_input"]["src_lengths"]
            prev_output_tokens = sample["net_input"]["prev_output_tokens"]
            padding_mask = sample["net_input"]["padding_mask"]
            assert prev_output_tokens.size(1) == 1, prev_output_tokens.size()
            return models[0].inference_s2c(
                x, xlen, prev_output_tokens, padding_mask=padding_mask
            )

    def filter_indices_by_size(
        self, indices, dataset, max_positions=None, ignore_invalid_inputs=False
    ):
        """
        Filter examples that are too large.

        Args:
            indices (np.array): original array of sample indices
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_positions (optional): max sentence length supported by the
                model (default: None). Unused here; the task-level
                ``self.max_pos`` limits are applied instead.
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False)

        Returns:
            np.array: array of filtered sample indices
        """
        indices, ignored = dataset.filter_indices_by_size(indices, self.max_pos)
        return indices
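

if __name__ == "__main__":
    # Minimal smoke test of the argument schema (an illustrative sketch only;
    # the manifest path below is a placeholder and no data is loaded).
    import argparse

    parser = argparse.ArgumentParser()
    ArTSTTask.add_args(parser)
    args = parser.parse_args(
        ["/path/to/manifest_root", "--t5-task", "s2t", "--sample-rate", "16000"]
    )
    print(args.t5_task, args.sample_rate, args.max_speech_positions)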