# JustinLin610
# update
# 8437114
# raw history blame
# No virus
# 5.84 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from pathlib import Path
from argparse import Namespace
from fairseq.data import Dictionary, encoders
from fairseq.data.audio.speech_to_text_dataset import (
S2TDataConfig,
SpeechToTextDataset,
SpeechToTextDatasetCreator,
get_features_or_waveform
)
from fairseq.tasks import LegacyFairseqTask, register_task
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@register_task("speech_to_text")
class SpeechToTextTask(LegacyFairseqTask):
    """Task registered as ``"speech_to_text"``.

    The task is configured from a manifest root directory (``args.data``)
    and a YAML file inside it (``args.config_yaml``), which is wrapped in an
    ``S2TDataConfig``. It owns the target dictionary and, when the data
    config names a ``speaker_set_filename``, a speaker-name-to-index mapping.
    """

    @classmethod
    def add_args(cls, parser):
        """Register this task's command-line arguments on ``parser``."""
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=6000,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )

    def __init__(self, args, tgt_dict):
        """Store the target dictionary and load the data config.

        Args:
            args: parsed arguments; ``args.data`` and ``args.config_yaml``
                locate the data config file.
            tgt_dict: target-side dictionary.
        """
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
        self.speaker_to_id = self._get_speaker_to_id()

    def _get_speaker_to_id(self):
        """Build the speaker-name -> index mapping, if one is configured.

        Returns:
            ``None`` when the data config has no ``speaker_set_filename``;
            otherwise a dict mapping each stripped line of that file (resolved
            under the manifest root) to its zero-based line number.
        """
        speaker_set_filename = self.data_cfg.config.get("speaker_set_filename")
        if speaker_set_filename is None:
            return None
        speaker_set_path = Path(self.args.data) / speaker_set_filename
        with open(speaker_set_path) as f:
            return {line.strip(): idx for idx, line in enumerate(f)}

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Load the target dictionary and construct the task.

        Raises:
            FileNotFoundError: if the vocabulary file named by the data
                config does not exist under the manifest root.
            ValueError: if ``args.train_subset`` is set and any of its
                comma-separated entries does not start with ``"train"``.
        """
        data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
        dict_path = Path(args.data) / data_cfg.vocab_filename
        if not dict_path.is_file():
            raise FileNotFoundError(f"Dict not found: {dict_path.as_posix()}")
        tgt_dict = Dictionary.load(dict_path.as_posix())
        logger.info(
            f"dictionary size ({data_cfg.vocab_filename}): {len(tgt_dict):,}"
        )

        train_subset = getattr(args, "train_subset", None)
        if train_subset is not None:
            # load_dataset() decides "is this a training split?" by the
            # "train" name prefix, so enforce the naming up front.
            if not all(s.startswith("train") for s in train_subset.split(",")):
                raise ValueError('Train splits should be named like "train*".')
        return cls(args, tgt_dict)

    def build_criterion(self, args):
        """Build the training criterion for this task.

        Raises:
            ValueError: if the data config prepends a target language tag
                but ``args.ignore_prefix_size`` is not 1.
        """
        # Imported here, as in the original, rather than at module level.
        from fairseq import criterions

        if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1:
            raise ValueError(
                'Please set "--ignore-prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        return criterions.build_criterion(args, self)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Create the dataset for ``split`` and store it in ``self.datasets``.

        A split is treated as a training split when its name starts with
        ``"train"``. ``combine`` and extra keyword arguments are accepted
        but not used here.
        """
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = SpeechToTextDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
            speaker_to_id=self.speaker_to_id,
        )

    @property
    def target_dictionary(self):
        """The target-side dictionary given at construction."""
        return self.tgt_dict

    @property
    def source_dictionary(self):
        """Always ``None``; this task keeps no source-side dictionary."""
        return None

    def max_positions(self):
        """Return ``(max_source_positions, max_target_positions)`` from args."""
        return self.args.max_source_positions, self.args.max_target_positions

    def build_model(self, args):
        """Copy data-config and speaker info onto ``args``, then build the model.

        Sets ``input_feat_per_channel``, ``input_channels`` and
        ``speaker_to_id`` on ``args`` before delegating to the parent class.
        """
        args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
        args.input_channels = self.data_cfg.input_channels
        args.speaker_to_id = self.speaker_to_id
        return super().build_model(args)

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        """Build a sequence generator that strips language-tag tokens.

        Args:
            models: models handed to the parent ``build_generator``.
            args: generation arguments; ``args.prefix_size`` must be 1 when
                the data config prepends a target language tag.
            seq_gen_cls: optional generator class, forwarded to the parent.
            extra_gen_cls_kwargs: optional extra generator keyword arguments.
                The mapping is copied, not modified in place.

        Raises:
            ValueError: if a target language tag is prepended but
                ``args.prefix_size`` is not 1.
        """
        if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
            raise ValueError(
                'Please set "--prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        lang_token_ids = {
            i
            for s, i in self.tgt_dict.indices.items()
            if SpeechToTextDataset.is_lang_tag(s)
        }

        # Work on a copy so the caller's dict is not mutated as a side effect.
        gen_kwargs = dict(extra_gen_cls_kwargs) if extra_gen_cls_kwargs else {}
        gen_kwargs["symbols_to_strip_from_output"] = lang_token_ids
        # Forward the caller's seq_gen_cls; it was previously dropped by
        # always passing None here.
        return super().build_generator(
            models,
            args,
            seq_gen_cls=seq_gen_cls,
            extra_gen_cls_kwargs=gen_kwargs,
        )

    def build_tokenizer(self, args):
        """Build the pre-tokenizer from the data config.

        ``args`` is not used; the settings come from
        ``self.data_cfg.pre_tokenizer``.
        """
        logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}")
        return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))

    def build_bpe(self, args):
        """Build the BPE tokenizer from the data config.

        ``args`` is not used; the settings come from
        ``self.data_cfg.bpe_tokenizer``.
        """
        logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}")
        return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer))

    def get_interactive_tokens_and_lengths(self, lines, encode_fn):
        """Return ``lines`` unchanged together with a length for each entry.

        Each length is the first dimension of what
        ``get_features_or_waveform`` returns for that entry (presumably a
        frame count for an audio/feature path -- confirm against the
        helper). ``encode_fn`` is not used.
        """
        n_frames = [get_features_or_waveform(p).shape[0] for p in lines]
        return lines, n_frames

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        """Wrap the given sources and lengths in a ``SpeechToTextDataset``.

        The dataset is built with the split name ``"interactive"`` and
        ``False`` as its second positional argument (presumably the
        is-train flag -- confirm against ``SpeechToTextDataset``).
        """
        return SpeechToTextDataset(
            "interactive", False, self.data_cfg, src_tokens, src_lengths
        )