# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
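"""fairseq task definition for ArTST.

Registers the ``artst`` task, covering speech-to-text (s2t), text-to-speech
(t2s), speech-to-speech (s2s), speech-to-class (s2c), and joint speech/text
pre-training.
"""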
import ast
import logging
import os.path as op
from argparse import Namespace
from collections import OrderedDict
import torch
from fairseq.data import (
Dictionary,
encoders,
PrependTokenDataset,
AppendTokenDataset,
data_utils,
StripTokenDataset,
TokenBlockDataset,
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq import utils
from artst.data.multitask_dataset import MultitaskDataset
from artst.data.speech_to_text_dataset import SpeechToTextDataset
from artst.data.text_to_speech_dataset import TextToSpeechDataset
from artst.data.speech_to_speech_dataset import SpeechToSpeechDataset
from artst.data.speech_to_class_dataset import SpeechToClassDataset
from artst.data.speech_dataset import SpeechPretrainDataset
from artst.data.text_dataset import TextPretrainDataset
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.tasks.hubert_pretraining import LabelEncoder
logger = logging.getLogger(__name__)
TASK_NAME = ["s2t", "t2s", "s2s", "s2c", "pretrain"]
@register_task("artst")
class ArTSTTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
parser.add_argument("data", help="manifest root path")
parser.add_argument(
"--config-yaml",
type=str,
default="config.yaml",
help="Configuration YAML filename (under manifest root)",
)
parser.add_argument(
"--max-speech-sample-size",
default=None,
type=int,
metavar="N",
help="max speech sample size",
)
parser.add_argument(
"--min-speech-sample-size",
default=None,
type=int,
metavar="N",
help="min speech sample size",
)
parser.add_argument(
"--max-speech-positions",
default=4000,
type=int,
metavar="N",
help="max number of tokens in the source sequence",
)
parser.add_argument(
"--max-text-positions",
default=450,
type=int,
metavar="N",
help="max number of tokens in the target sequence",
)
        parser.add_argument(
            "--t5-task",
            choices=TASK_NAME,
            help="training task: s2t, t2s, s2s, s2c, or pretrain",
        )
parser.add_argument(
"--bpe-tokenizer",
type=str,
default=None,
help="bpe tokenizer for s2t",
)
# Speaker Identification (SID)
parser.add_argument(
"--finetune-from-modules",
default=None,
# choices=[
# "encoder-decoder", "encoder", "decoder",
# "speech_encoder_prenet-encoder-decoder-text_decoder_prenet-text_decoder_postnet", # ASR, T5 SID
# "speech_encoder_prenet-encoder-decoder-text_decoder_prenet-speaker_decoder_postnet", # SID
# "speech_encoder_prenet-encoder-decoder-speech_decoder_prenet-speech_decoder_postnet", # VC, SE
# "text_encoder_prenet-encoder-decoder-speech_decoder_prenet-speech_decoder_postnet", # TTS
# ],
help="If set, using part modules of finetune model.",
)
parser.add_argument(
"--finetune-out-of-modules",
default=None,
# choices=[
# "speaker_decoder_postnet", # SID
# "speech_decoder_postnet", # SE with reduction factor 1
# ],
help="If set, remove part modules of finetune model.",
)
# BART
parser.add_argument(
"--shorten-method",
default="none",
choices=["none", "truncate", "random_crop"],
help="if not none, shorten sequences that exceed --tokens-per-sample",
)
parser.add_argument(
"--shorten-data-split-list",
default="",
help="comma-separated list of dataset splits to apply shortening to, "
'e.g., "train,valid" (default: all dataset splits)',
)
parser.add_argument(
"--tokens-per-sample",
default=512,
type=int,
help="max number of total tokens over all segments"
" per sample for dataset",
)
parser.add_argument(
"--sample-break-mode",
default="eos",
type=str,
help="mode for breaking sentence",
)
parser.add_argument(
"--mask",
default=0.3,
type=float,
help="fraction of words/subwords that will be masked",
)
parser.add_argument(
"--mask-random",
default=0.1,
type=float,
help="instead of using [MASK], use random token this often",
)
parser.add_argument(
"--insert",
default=0.0,
type=float,
help="insert this percentage of additional random tokens",
)
parser.add_argument(
"--permute",
default=0.0,
type=float,
help="take this proportion of subwords and permute them",
)
parser.add_argument(
"--rotate",
default=0.0,
type=float,
help="rotate this proportion of inputs",
)
parser.add_argument(
"--poisson-lambda",
default=3.5,
type=float,
help="randomly shuffle sentences for this proportion of inputs",
)
parser.add_argument(
"--permute-sentences",
default=0.0,
type=float,
help="shuffle this proportion of sentences in all inputs",
)
        # --mask-length must stay registered: load_dataset reads args.mask_length
        # when deciding whether to build a whole-word mask
        parser.add_argument(
            "--mask-length",
            default="span-poisson",
            type=str,
            choices=["subword", "word", "span-poisson"],
            help="mask length to choose",
        )
parser.add_argument(
"--replace-length",
default=1,
type=int,
help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
)
parser.add_argument(
"--iid-noise-target",
action="store_true",
help="whether to use t5 form target",
)
# Hubert
parser.add_argument(
"--hubert-labels",
nargs="*",
type=str,
default=['km'],
help="extension of the label files to load, frame-level labels for pre-training, and sequence-level label for fine-tuning",
)
parser.add_argument(
"--hubert-label-dir",
type=str,
default=None,
help="if set, looks for labels in this directory instead",
)
parser.add_argument(
"--sample-rate",
default=100,
type=float,
help="target sample rate. audio files will be up/down sampled to this rate",
)
parser.add_argument(
"--label-rates",
default=-1,
type=float,
help="if set, looks for labels in this directory instead",
)
parser.add_argument(
"--normalize",
action="store_true",
help="if set, normalizes input to have 0 mean and unit variance",
)
parser.add_argument(
"--enable-padding",
action="store_true",
help="pad shorter samples instead of cropping",
)
parser.add_argument(
"--pad-audio",
action="store_true",
help="pad audio to the longest one in the batch if true",
)
parser.add_argument(
"--random-crop",
action="store_true",
help="always crop from the beginning if false",
)
parser.add_argument(
"--single-target",
action="store_true",
help="if set, AddTargetDatasets outputs same keys "
"as AddTargetDataset",
)
        parser.add_argument(
            "--batch-ratio",
            default=None,
            type=str,
            help="ratio of batch sizes across datasets",
        )
        parser.add_argument(
            "--sample-ratios",
            default=None,
            type=str,
            help="sampling ratio for each dataset",
        )
parser.add_argument(
"--ctc-weight",
type=float,
default=0.0,
help="ctc weight for inference",
)
        parser.add_argument(
            "--inference-speech",
            # argparse's type=bool treats any non-empty string as True,
            # so parse the flag value explicitly
            type=lambda x: str(x).lower() in ("true", "1", "yes"),
            default=False,
            help="run speech (TTS) inference",
        )
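    # Example invocation (a sketch; paths and split names are hypothetical, and
    # required fairseq flags such as --arch and --criterion are omitted):
    #
    #   fairseq-train /path/to/manifests \
    #       --task artst --t5-task s2t \
    #       --sample-rate 16000 --normalize \
    #       --hubert-label-dir /path/to/labels \
    #       --bpe-tokenizer /path/to/spm.model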
def __init__(self, args, dicts, config):
super().__init__(args)
self.dicts = dicts
self.config = config
self.t5_task = args.t5_task
# Used for filter size
if self.t5_task in ['s2t', 't2s', 's2s', 's2c']:
self.max_pos = [self.args.max_speech_positions * 256]
elif self.t5_task == 'pretrain':
self.max_pos = [self.args.max_speech_positions * 256, self.args.max_text_positions]
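        # Worked example of the 256 factor (an assumption; not documented here):
        # with 16 kHz audio and a 256-sample hop between spectrogram frames,
        # max_speech_positions=4000 frames corresponds to 4000 * 256 = 1,024,000
        # waveform samples, i.e. 64 seconds.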
self.mask_idx = self.dicts["text"].add_symbol("<mask>")
        # add blank token for CTC (added unconditionally, even if ctc_weight == 0)
self.blank_symbol_idx = self.dicts["text"].add_symbol("<ctc_blank>")
self.blank_symbol = "<ctc_blank>"
        # add unique sentinel mask tokens (<mask>0 ... <mask>599) for T5-style targets
if hasattr(args, "iid_noise_target") and args.iid_noise_target:
self.uni_mask_idxs = []
for i in range(600):
self.uni_mask_idxs.append(self.dicts["text"].add_symbol("<mask>" + str(i)))
self.uni_mask_idxs = torch.tensor(self.uni_mask_idxs)
self.seed = args.seed
@classmethod
def setup_task(cls, args, **kwargs):
# load dictionaries and config
dicts = OrderedDict()
if args.t5_task == 'pretrain' and not hasattr(args, "shuffle_instance"):
args.shuffle_instance = False
# Prepare config
config = None
        logger.info(f"No config file for {args.t5_task}")
if args.t5_task == "pretrain":
dicts["hubert"] = [Dictionary.load(f"{args.hubert_label_dir}/dict.{label}.txt") for label in args.hubert_labels]
dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
else:
if config is None:
dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
else:
dicts["text"] = Dictionary.load(op.join(args.data, config.vocab_filename))
return cls(args, dicts, config)
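    # Minimal programmatic sketch (hypothetical paths; fairseq-train normally
    # calls setup_task itself after parsing the flags registered above):
    #
    #   from argparse import Namespace
    #   args = Namespace(data="/path/to/manifests", t5_task="s2t", seed=1,
    #                    max_speech_positions=4000, hubert_labels=["km"],
    #                    hubert_label_dir="/path/to/labels")
    #   task = ArTSTTask.setup_task(args)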
def build_criterion(self, args):
from fairseq import criterions
return criterions.build_criterion(args, self)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
sample_ratios = []
if self.t5_task == "s2t":
## For speech to text task
bpe_tokenizer = self.build_bpe(self.args)
manifest = f"{self.args.data}/{split}.tsv"
procs = [LabelEncoder(self.dicts["text"])]
paths = [f"{self.args.hubert_label_dir}/{split}.txt"]
            # log the manifest being loaded for easier debugging
            logger.info(f"Manifest: {manifest}")
self.datasets[split] = SpeechToTextDataset(
manifest,
sample_rate=self.args.sample_rate,
label_paths=paths,
label_processors=procs,
max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size,
min_keep_sample_size=self.args.min_speech_sample_size,
normalize=self.args.normalize,
store_labels=False,
tgt_dict=self.dicts["text"],
tokenizer=bpe_tokenizer,
)
elif self.t5_task == "t2s":
## For text to speech task
from fairseq.data import ConcatDataset
bpe_tokenizer = self.build_bpe(self.args)
procs = [LabelEncoder(self.dicts["text"])]
t2s_datasets = [
TextToSpeechDataset(
manifest_path=f"{self.args.data}/{name}.tsv",
sample_rate=self.args.sample_rate,
label_paths=[f"{self.args.hubert_label_dir}/{name}.txt"],
label_processors=procs,
max_keep_sample_size=self.max_pos[0],
normalize=self.args.normalize,
store_labels=False,
src_dict=self.dicts["text"],
tokenizer=bpe_tokenizer,
reduction_factor=self.args.reduction_factor,
inference=self.args.inference_speech,
)
for name in split.split(",")
]
self.datasets[split] = ConcatDataset(t2s_datasets) if len(t2s_datasets) > 1 else t2s_datasets[0]
elif self.t5_task == "s2s":
manifest = f"{self.args.data}/{split}.tsv"
self.datasets[split] = SpeechToSpeechDataset(
manifest_path=manifest,
sample_rate=self.args.sample_rate,
max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size,
min_keep_sample_size=self.args.min_speech_sample_size,
normalize=self.args.normalize,
reduction_factor=self.args.reduction_factor,
)
elif self.t5_task == "s2c":
is_train_split = ("train" in split)
is_valid_split = ("valid" in split)
if is_train_split:
max_length = 51200
elif is_valid_split:
max_length = 76800
else:
max_length = 2560000
manifest = op.join(f"{self.args.data}", f"{split}.tsv")
procs = LabelEncoder(self.dicts["text"]) # map speaker to id
self.datasets[split] = SpeechToClassDataset(
manifest_path=manifest,
sample_rate=self.args.sample_rate,
label_processors=procs,
max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size,
min_keep_sample_size=self.args.min_speech_sample_size,
normalize=self.args.normalize,
tgt_dict=self.dicts["text"],
max_length=max_length
)
elif self.t5_task == "pretrain":
is_train_split = ("train" in split)
pretrain_datasets = []
speech_split, text_split = split.split('|')
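            # the pretrain split string names a speech manifest and a text split
            # joined by "|", e.g. "train_speech|train_text" (hypothetical names)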
## Speech pre-train
manifest = f"{self.args.data}/{speech_split}.tsv"
dicts = self.dicts["hubert"]
            pad_list = [d.pad() for d in dicts]
            eos_list = [d.eos() for d in dicts]
            procs = [LabelEncoder(d) for d in dicts]
paths = [
f"{self.args.hubert_label_dir}/{speech_split}.{l}" for l in self.args.hubert_labels
]
# hubert v1: pad_audio=True, random_crop=False;
self.args.dec_weight = getattr(self.args, "dec_weight", 1.0)
pretrain_datasets.append(
SpeechPretrainDataset(
manifest,
sample_rate=self.args.sample_rate,
label_paths=paths,
label_rates=self.args.label_rates,
pad_list=pad_list,
eos_list=eos_list,
label_processors=procs,
max_keep_sample_size=None,
min_keep_sample_size=32000,
max_sample_size=self.args.max_speech_sample_size,
pad_audio=self.args.pad_audio,
normalize=self.args.normalize,
store_labels=False,
random_crop=self.args.random_crop,
single_target=self.args.single_target,
reduction_factor=self.args.reduction_factor,
)
)
sample_ratios.append(sum([pretrain_datasets[0].size(i) for i in range(len(pretrain_datasets[0]))]))
## Text pre-train
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
print(f"Loading {text_split} from data_path={data_path}")
split_path = op.join(data_path, text_split)
print(f"split_path={split_path}")
bart_dataset = data_utils.load_indexed_dataset(
split_path,
self.dicts["text"],
self.args.dataset_impl,
combine=combine,
)
if bart_dataset is None:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(text_split, split_path)
)
bart_dataset = StripTokenDataset(bart_dataset, self.dicts["text"].eos())
bart_dataset = maybe_shorten_dataset(
bart_dataset,
text_split,
self.args.shorten_data_split_list,
self.args.shorten_method,
self.args.tokens_per_sample,
self.args.seed,
)
# create continuous blocks of tokens
bart_dataset = TokenBlockDataset(
bart_dataset,
bart_dataset.sizes,
self.args.tokens_per_sample - 2, # one less for <s> and one for </s>
pad=self.dicts["text"].pad(),
eos=self.dicts["text"].eos(),
break_mode=self.args.sample_break_mode,
document_sep_len=0,
)
# prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
bart_dataset = PrependTokenDataset(bart_dataset, self.dicts["text"].bos())
bart_dataset = AppendTokenDataset(bart_dataset, self.dicts["text"].eos())
mask_whole_words = (
get_whole_word_mask(self.args, self.dicts["text"])
if self.args.mask_length != "subword"
else None
)
self.args.bert_weight = getattr(self.args, "bert_weight", 0.0)
pretrain_datasets.append(
TextPretrainDataset(
bart_dataset,
bart_dataset.sizes,
self.dicts["text"],
self.mask_idx,
mask_whole_words,
shuffle=self.args.shuffle_instance,
seed=self.seed,
args=self.args,
iid_noise_target=self.args.iid_noise_target,
uni_mask_idxs=self.uni_mask_idxs if self.args.iid_noise_target else None,
)
)
sample_ratios.append(sum(pretrain_datasets[1].sizes))
logger.info(
"Task: {0}, Loaded {1} samples of denoising_dataset".format(
'bart',
len(pretrain_datasets[1]),
)
)
            logger.info(f"token ratio is {sample_ratios}")
            if self.args.batch_ratio is not None:
                # parse the CLI string safely instead of using eval()
                batch_ratio = ast.literal_eval(self.args.batch_ratio)
                assert len(batch_ratio) == len(sample_ratios)
                sample_ratios = [sample_ratios[i] / batch_ratio[i] for i in range(len(sample_ratios))]
            else:
                batch_ratio = None
            max_size = max(sample_ratios)
            sample_ratios = [max_size / r for r in sample_ratios]
            if hasattr(self.args, "sample_ratios") and self.args.sample_ratios is not None:
                sample_ratios = ast.literal_eval(self.args.sample_ratios)
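            # Worked example with hypothetical counts: sample_ratios == [8e6, 2e6]
            # and no --batch-ratio gives max_size == 8e6, so sample_ratios becomes
            # [1.0, 4.0] and MultitaskDataset upsamples the smaller (text) side 4x.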
if is_train_split:
self.datasets[split] = MultitaskDataset(
pretrain_datasets, sample_ratios, batch_ratio
)
else:
self.datasets[split] = MultitaskDataset(
pretrain_datasets, batch_ratio=batch_ratio
)
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
model.train()
model.set_num_updates(update_num)
        # do not use the criterion's sample_size; normalize the loss locally instead
agg_loss, agg_sample_size, agg_logging_output = 0.0, 1.0, {}
agg_logging_output['sample_size'] = 1
def forward_backward(model, samples, weight=1.0):
nonlocal agg_loss, agg_logging_output
if samples is None or len(samples) == 0:
return
loss, sample_size, logging_output = criterion(model, samples)
if ignore_grad:
loss *= 0
else:
loss *= weight
loss = loss / sample_size
optimizer.backward(loss)
agg_loss += loss.detach().item()
            # TODO: make summing of the sample sizes configurable
            for k in logging_output:
                if k == 'ntokens' or k == 'nsentences':
                    if k not in agg_logging_output:
                        agg_logging_output[k] = 0
                    agg_logging_output[k] += logging_output[k]
            agg_logging_output[samples['task_name']] = logging_output
forward_backward(model, sample)
agg_logging_output["loss"] = agg_loss
return agg_loss, agg_sample_size, agg_logging_output
def valid_step(self, sample, model, criterion):
model.eval()
with torch.no_grad():
from collections import defaultdict
agg_loss, agg_sample_size, agg_logging_output = 0.0, 1.0, defaultdict(float)
agg_logging_output['sample_size'] = 1
loss, sample_size, logging_output = criterion(model, sample)
loss = loss / sample_size
            agg_loss += loss.item() if isinstance(loss, torch.Tensor) else loss
agg_logging_output[sample['task_name']] = logging_output
agg_logging_output["loss"] = agg_loss
return agg_loss, agg_sample_size, agg_logging_output
@property
def target_dictionary(self):
return self.dicts["text"]
@property
def source_dictionary(self):
return None
def build_model(self, args):
try:
args.input_feat_per_channel = self.config.input_feat_per_channel
args.input_channels = self.config.input_channels
        except Exception as e:
            args.input_feat_per_channel = 80
            args.input_channels = 1
            logger.warning(
                f"Cannot read input_feat_per_channel/input_channels from config ({e}); "
                f"falling back to input_feat_per_channel={args.input_feat_per_channel}, "
                f"input_channels={args.input_channels}"
            )
args.speech_odim = args.input_feat_per_channel * args.input_channels
args.label_rates = self.args.label_rates
args.sample_rate = self.args.sample_rate
self.args.reduction_factor = args.reduction_factor
return super(ArTSTTask, self).build_model(args)
def build_generator(
self,
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=None,
):
from artst.sequence_generator import SequenceGenerator
        # extra_gen_cls_kwargs may be None; guard before unpacking
        extra_gen_cls_kwargs = {
            "ctc_weight": self.args.ctc_weight,
            **(extra_gen_cls_kwargs or {}),
        }
return super().build_generator(
models, args, seq_gen_cls=SequenceGenerator, extra_gen_cls_kwargs=extra_gen_cls_kwargs
)
def build_tokenizer(self, args):
if self.config is None:
logger.info(f"pre-tokenizer: None")
return encoders.build_tokenizer(Namespace(**{"tokenizer": None}))
else:
logger.info(f"pre-tokenizer: {self.config.pre_tokenizer}")
return encoders.build_tokenizer(Namespace(**self.config.pre_tokenizer))
def build_bpe(self, args):
if self.config is not None:
logger.info(f"tokenizer: {self.config.bpe_tokenizer}")
return encoders.build_bpe(Namespace(**self.config.bpe_tokenizer))
else:
logger.info(f"tokenizer: {self.args.bpe_tokenizer}")
return encoders.build_bpe(Namespace(**{"bpe": "sentencepiece", "sentencepiece_model": self.args.bpe_tokenizer}))
def generate_class(self, models, net_input, prefix_tokens, **kwargs):
with torch.no_grad():
encoder_input = {
k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
}
encoder_input.update(kwargs)
encoder_input.update({"prev_output_tokens": prefix_tokens})
return models[0].generate_class(**encoder_input)
def generate_speech(self, models, net_input, **kwargs):
with torch.no_grad():
encoder_input = {
k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
}
encoder_input.update(kwargs)
return models[0].generate_speech(**encoder_input)
def inference_t2s(
self, models, sample
):
with torch.no_grad():
xs = sample['net_input']['src_tokens']
spkemb = sample['net_input']['spkembs']
return models[0].inference(xs, spkemb)
def inference_s2s(
self, models, sample, force_equal_length=False
):
with torch.no_grad():
x = sample['net_input']['src_tokens']
xlen = sample['net_input']['src_lengths']
spkemb = sample['net_input']['spkembs']
prev_output_tokens = sample['net_input']['prev_output_tokens']
padding_mask = sample['net_input']['padding_mask']
tgt_lengths = sample['net_input']['tgt_lengths']
return models[0].inference_s2s(x, xlen, spkemb, prev_output_tokens, tgt_lengths, force_equal_length=force_equal_length, padding_mask=padding_mask)
def inference_s2c(
self, models, sample
):
with torch.no_grad():
x = sample['net_input']['src_tokens']
xlen = sample['net_input']['src_lengths']
prev_output_tokens = sample['net_input']['prev_output_tokens']
padding_mask = sample['net_input']['padding_mask']
assert prev_output_tokens.size(1) == 1, prev_output_tokens.size()
return models[0].inference_s2c(x, xlen, prev_output_tokens, padding_mask=padding_mask)
def filter_indices_by_size(
self, indices, dataset, max_positions=None, ignore_invalid_inputs=False
):
"""
Filter examples that are too large
Args:
indices (np.array): original array of sample indices
dataset (~fairseq.data.FairseqDataset): dataset to batch
max_positions (optional): max sentence length supported by the
model (default: None).
ignore_invalid_inputs (bool, optional): don't raise Exception for
sentences that are too long (default: False).
Returns:
np.array: array of filtered sample indices
"""
indices, ignored = dataset.filter_indices_by_size(
indices,
self.max_pos
)
return indices