#!/usr/bin/env python3

"""Script to run the inference of text-to-speech model."""

import argparse
import logging
import shutil
import sys
import time
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import soundfile as sf
import torch
from packaging.version import parse as V
from typeguard import typechecked

from espnet2.fileio.npy_scp import NpyScpWriter
from espnet2.gan_tts.vits import VITS
from espnet2.tasks.tts import TTSTask
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
from espnet2.tts.fastspeech import FastSpeech
from espnet2.tts.fastspeech2 import FastSpeech2
from espnet2.tts.tacotron2 import Tacotron2
from espnet2.tts.transformer import Transformer
from espnet2.tts.utils import DurationCalculator
from espnet2.utils import config_argparse
from espnet2.utils.types import str2bool, str2triple_str, str_or_none
from espnet.utils.cli_utils import get_commandline_args


class Text2Speech:
    """Text2Speech class.

    Examples:
        >>> from espnet2.bin.tts_inference import Text2Speech
        >>> # Case 1: Load the local model and use Griffin-Lim vocoder
        >>> text2speech = Text2Speech(
        >>>     train_config="/path/to/config.yml",
        >>>     model_file="/path/to/model.pth",
        >>> )
        >>> # Case 2: Load the local model and the pretrained vocoder
        >>> text2speech = Text2Speech.from_pretrained(
        >>>     train_config="/path/to/config.yml",
        >>>     model_file="/path/to/model.pth",
        >>>     vocoder_tag="kan-bayashi/ljspeech_tacotron2",
        >>> )
        >>> # Case 3: Load the pretrained model and use Griffin-Lim vocoder
        >>> text2speech = Text2Speech.from_pretrained(
        >>>     model_tag="kan-bayashi/ljspeech_tacotron2",
        >>> )
        >>> # Case 4: Load the pretrained model and the pretrained vocoder
        >>> text2speech = Text2Speech.from_pretrained(
        >>>     model_tag="kan-bayashi/ljspeech_tacotron2",
        >>>     vocoder_tag="parallel_wavegan/ljspeech_parallel_wavegan.v1",
        >>> )
        >>> # Run inference and save as wav file
        >>> import soundfile as sf
        >>> wav = text2speech("Hello, World")["wav"]
        >>> sf.write("out.wav", wav.numpy(), text2speech.fs, "PCM_16")

    """

    @typechecked
    def __init__(
        self,
        train_config: Union[Path, str, None] = None,
        model_file: Union[Path, str, None] = None,
        threshold: float = 0.5,
        minlenratio: float = 0.0,
        maxlenratio: float = 10.0,
        use_teacher_forcing: bool = False,
        use_att_constraint: bool = False,
        backward_window: int = 1,
        forward_window: int = 3,
        speed_control_alpha: float = 1.0,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        vocoder_config: Union[Path, str, None] = None,
        vocoder_file: Union[Path, str, None] = None,
        dtype: str = "float32",
        device: str = "cpu",
        seed: int = 777,
        always_fix_seed: bool = False,
        prefer_normalized_feats: bool = False,
    ):
        """Initialize Text2Speech module."""
        # setup model
        model, train_args = TTSTask.build_model_from_file(
            train_config, model_file, device
        )
        model.to(dtype=getattr(torch, dtype)).eval()
        self.device = device
        self.dtype = dtype
        self.train_args = train_args
        self.model = model
        self.tts = model.tts
        self.normalize = model.normalize
        self.feats_extract = model.feats_extract
        self.duration_calculator = DurationCalculator()
        self.preprocess_fn = TTSTask.build_preprocess_fn(train_args, False)
        self.use_teacher_forcing = use_teacher_forcing
        self.seed = seed
        self.always_fix_seed = always_fix_seed
        self.vocoder = None
        self.prefer_normalized_feats = prefer_normalized_feats
        if self.tts.require_vocoder:
            vocoder = TTSTask.build_vocoder_from_file(
                vocoder_config, vocoder_file, model, device
            )
            if isinstance(vocoder, torch.nn.Module):
                vocoder.to(dtype=getattr(torch, dtype)).eval()
            self.vocoder = vocoder
        logging.info(f"Extractor:\n{self.feats_extract}")
        logging.info(f"Normalizer:\n{self.normalize}")
        logging.info(f"TTS:\n{self.tts}")
        if self.vocoder is not None:
            logging.info(f"Vocoder:\n{self.vocoder}")

        # setup decoding config
        decode_conf = {}
        decode_conf.update(use_teacher_forcing=use_teacher_forcing)
        if isinstance(self.tts, (Tacotron2, Transformer)):
            decode_conf.update(
                threshold=threshold,
                maxlenratio=maxlenratio,
                minlenratio=minlenratio,
            )
        if isinstance(self.tts, Tacotron2):
            decode_conf.update(
                use_att_constraint=use_att_constraint,
                forward_window=forward_window,
                backward_window=backward_window,
            )
        if isinstance(self.tts, (FastSpeech, FastSpeech2, VITS)):
            decode_conf.update(alpha=speed_control_alpha)
        if isinstance(self.tts, VITS):
            decode_conf.update(
                noise_scale=noise_scale,
                noise_scale_dur=noise_scale_dur,
            )
        self.decode_conf = decode_conf

    @torch.no_grad()
    @typechecked
    def __call__(
        self,
        text: Union[str, torch.Tensor, np.ndarray],
        speech: Union[torch.Tensor, np.ndarray, None] = None,
        durations: Union[torch.Tensor, np.ndarray, None] = None,
        spembs: Union[torch.Tensor, np.ndarray, None] = None,
        sids: Union[torch.Tensor, np.ndarray, None] = None,
        lids: Union[torch.Tensor, np.ndarray, None] = None,
        decode_conf: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run text-to-speech."""
        # check inputs
        if self.use_speech and speech is None:
            raise RuntimeError("Missing required argument: 'speech'")
        if self.use_sids and sids is None:
            raise RuntimeError("Missing required argument: 'sids'")
        if self.use_lids and lids is None:
            raise RuntimeError("Missing required argument: 'lids'")
        if self.use_spembs and spembs is None:
            raise RuntimeError("Missing required argument: 'spembs'")

        # prepare batch
        if isinstance(text, str):
            text = self.preprocess_fn("", dict(text=text))["text"]
        batch = dict(text=text)
        if speech is not None:
            batch.update(speech=speech)
        if durations is not None:
            batch.update(durations=durations)
        if spembs is not None:
            batch.update(spembs=spembs)
        if sids is not None:
            batch.update(sids=sids)
        if lids is not None:
            batch.update(lids=lids)
        batch = to_device(batch, self.device)

        # overwrite the decode configs if provided
        cfg = self.decode_conf
        if decode_conf is not None:
            cfg = self.decode_conf.copy()
            cfg.update(decode_conf)

        # inference
        if self.always_fix_seed:
            set_all_random_seed(self.seed)
        output_dict = self.model.inference(**batch, **cfg)

        # calculate additional metrics
        if output_dict.get("att_w") is not None:
            duration, focus_rate = self.duration_calculator(output_dict["att_w"])
            output_dict.update(duration=duration, focus_rate=focus_rate)

        # apply vocoder (mel-to-wav)
        if self.vocoder is not None:
            if (
                self.prefer_normalized_feats
                or output_dict.get("feat_gen_denorm") is None
            ):
                input_feat = output_dict["feat_gen"]
            else:
                input_feat = output_dict["feat_gen_denorm"]
            wav = self.vocoder(input_feat)
            output_dict.update(wav=wav)

        return output_dict

    @property
    def fs(self) -> Optional[int]:
        """Return sampling rate."""
        if hasattr(self.vocoder, "fs"):
            return self.vocoder.fs
        elif hasattr(self.tts, "fs"):
            return self.tts.fs
        else:
            return None

    @property
    def use_speech(self) -> bool:
        """Return whether speech is needed in the inference."""
        return self.use_teacher_forcing or getattr(self.tts, "use_gst", False)

    @property
    def use_sids(self) -> bool:
        """Return whether sid is needed in the inference."""
        return self.tts.spks is not None

    @property
    def use_lids(self) -> bool:
        """Return whether lid is needed in the inference."""
        return self.tts.langs is not None

    @property
    def use_spembs(self) -> bool:
        """Return whether spemb is needed in the inference."""
        return self.tts.spk_embed_dim is not None
    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        vocoder_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ):
        """Build Text2Speech instance from the pretrained model.

        Args:
            model_tag (Optional[str]): Model tag of the pretrained models.
                Currently, the tags of espnet_model_zoo are supported.
            vocoder_tag (Optional[str]): Vocoder tag of the pretrained vocoders.
                Currently, the tags of parallel_wavegan are supported, which should
                start with the prefix "parallel_wavegan/".

        Returns:
            Text2Speech: Text2Speech instance.

        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader
            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        if vocoder_tag is not None:
            if vocoder_tag.startswith("parallel_wavegan/"):
                try:
                    from parallel_wavegan.utils import download_pretrained_model
                except ImportError:
                    logging.error(
                        "`parallel_wavegan` is not installed. "
                        "Please install via `pip install -U parallel_wavegan`."
                    )
                    raise
                from parallel_wavegan import __version__

                # NOTE(kan-bayashi): Filelock download is supported from 0.5.2
                assert V(__version__) > V("0.5.1"), (
                    "Please install the latest parallel_wavegan "
                    "via `pip install -U parallel_wavegan`."
                )
                vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "")
                vocoder_file = download_pretrained_model(vocoder_tag)
                vocoder_config = Path(vocoder_file).parent / "config.yml"
                kwargs.update(vocoder_config=vocoder_config, vocoder_file=vocoder_file)
            else:
                raise ValueError(f"{vocoder_tag} is an unsupported format.")

        return Text2Speech(**kwargs)

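# NOTE: decoding options can also be overridden per call through the
# `decode_conf` argument of `Text2Speech.__call__`, which merges the given
# dict into a copy of `self.decode_conf`. A minimal sketch (the paths are
# placeholders, and the `alpha` key only applies to duration-controllable
# models, i.e. FastSpeech / FastSpeech2 / VITS):
#
#     text2speech = Text2Speech(
#         train_config="/path/to/config.yml",
#         model_file="/path/to/model.pth",
#     )
#     output_dict = text2speech("Hello, World", decode_conf={"alpha": 1.2})
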
@typechecked
def inference(
    output_dir: Union[Path, str],
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    model_tag: Optional[str],
    threshold: float,
    minlenratio: float,
    maxlenratio: float,
    use_teacher_forcing: bool,
    use_att_constraint: bool,
    backward_window: int,
    forward_window: int,
    speed_control_alpha: float,
    noise_scale: float,
    noise_scale_dur: float,
    always_fix_seed: bool,
    allow_variable_data_keys: bool,
    vocoder_config: Optional[str],
    vocoder_file: Optional[str],
    vocoder_tag: Optional[str],
):
    """Run text-to-speech inference."""
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build model
    text2speech_kwargs = dict(
        train_config=train_config,
        model_file=model_file,
        threshold=threshold,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        use_teacher_forcing=use_teacher_forcing,
        use_att_constraint=use_att_constraint,
        backward_window=backward_window,
        forward_window=forward_window,
        speed_control_alpha=speed_control_alpha,
        noise_scale=noise_scale,
        noise_scale_dur=noise_scale_dur,
        vocoder_config=vocoder_config,
        vocoder_file=vocoder_file,
        dtype=dtype,
        device=device,
        seed=seed,
        always_fix_seed=always_fix_seed,
    )
    text2speech = Text2Speech.from_pretrained(
        model_tag=model_tag,
        vocoder_tag=vocoder_tag,
        **text2speech_kwargs,
    )

    # 3. Build data-iterator
    if not text2speech.use_speech:
        data_path_and_name_and_type = list(
            filter(lambda x: x[1] != "speech", data_path_and_name_and_type)
        )
    loader = TTSTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=TTSTask.build_preprocess_fn(text2speech.train_args, False),
        collate_fn=TTSTask.build_collate_fn(text2speech.train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Start for-loop
    output_dir = Path(output_dir)
    (output_dir / "norm").mkdir(parents=True, exist_ok=True)
    (output_dir / "denorm").mkdir(parents=True, exist_ok=True)
    (output_dir / "speech_shape").mkdir(parents=True, exist_ok=True)
    (output_dir / "wav").mkdir(parents=True, exist_ok=True)
    (output_dir / "att_ws").mkdir(parents=True, exist_ok=True)
    (output_dir / "probs").mkdir(parents=True, exist_ok=True)
    (output_dir / "durations").mkdir(parents=True, exist_ok=True)
    (output_dir / "focus_rates").mkdir(parents=True, exist_ok=True)

    # Lazy load to avoid the backend error
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator
    with NpyScpWriter(
        output_dir / "norm",
        output_dir / "norm/feats.scp",
    ) as norm_writer, NpyScpWriter(
        output_dir / "denorm", output_dir / "denorm/feats.scp"
    ) as denorm_writer, open(
        output_dir / "speech_shape/speech_shape", "w"
    ) as shape_writer, open(
        output_dir / "durations/durations", "w"
    ) as duration_writer, open(
        output_dir / "focus_rates/focus_rates", "w"
    ) as focus_rate_writer:
        for idx, (keys, batch) in enumerate(loader, 1):
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert _bs == 1, _bs

            # Change to single sequence and remove *_lengths
            # because inference() requires 1-seq, not mini-batch.
            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            start_time = time.perf_counter()
            output_dict = text2speech(**batch)

            key = keys[0]
            insize = next(iter(batch.values())).size(0) + 1
            if output_dict.get("feat_gen") is not None:
                # standard text2mel model case
                feat_gen = output_dict["feat_gen"]
                logging.info(
                    "inference speed = {:.1f} frames / sec.".format(
                        int(feat_gen.size(0)) / (time.perf_counter() - start_time)
                    )
                )
                logging.info(f"{key} (size:{insize}->{feat_gen.size(0)})")
                if feat_gen.size(0) == insize * maxlenratio:
                    logging.warning(f"output length reaches maximum length ({key}).")

                norm_writer[key] = output_dict["feat_gen"].cpu().numpy()
                shape_writer.write(
                    f"{key} " + ",".join(map(str, output_dict["feat_gen"].shape)) + "\n"
                )
                if output_dict.get("feat_gen_denorm") is not None:
                    denorm_writer[key] = output_dict["feat_gen_denorm"].cpu().numpy()
            else:
                # end-to-end text2wav model case
                wav = output_dict["wav"]
                logging.info(
                    "inference speed = {:.1f} points / sec.".format(
                        int(wav.size(0)) / (time.perf_counter() - start_time)
                    )
                )
                logging.info(f"{key} (size:{insize}->{wav.size(0)})")

            if output_dict.get("duration") is not None:
                # Save durations and focus rates
                duration_writer.write(
                    f"{key} "
                    + " ".join(map(str, output_dict["duration"].long().cpu().numpy()))
                    + "\n"
                )
            if output_dict.get("focus_rate") is not None:
                focus_rate_writer.write(
                    f"{key} {float(output_dict['focus_rate']):.5f}\n"
                )

            if output_dict.get("att_w") is not None:
                # Plot attention weight
                att_w = output_dict["att_w"].cpu().numpy()

                if att_w.ndim == 2:
                    att_w = att_w[None][None]
                elif att_w.ndim != 4:
                    raise RuntimeError(f"Must be 2 or 4 dimension: {att_w.ndim}")

                w, h = plt.figaspect(att_w.shape[0] / att_w.shape[1])
                fig = plt.Figure(
                    figsize=(
                        w * 1.3 * min(att_w.shape[0], 2.5),
                        h * 1.3 * min(att_w.shape[1], 2.5),
                    )
                )
                fig.suptitle(f"{key}")
                axes = fig.subplots(att_w.shape[0], att_w.shape[1])
                if len(att_w) == 1:
                    axes = [[axes]]
                for ax, att_w in zip(axes, att_w):
                    for ax_, att_w_ in zip(ax, att_w):
                        ax_.imshow(att_w_.astype(np.float32), aspect="auto")
                        ax_.set_xlabel("Input")
                        ax_.set_ylabel("Output")
                        ax_.xaxis.set_major_locator(MaxNLocator(integer=True))
                        ax_.yaxis.set_major_locator(MaxNLocator(integer=True))
                fig.set_tight_layout({"rect": [0, 0.03, 1, 0.95]})
                fig.savefig(output_dir / f"att_ws/{key}.png")
                fig.clf()

            if output_dict.get("prob") is not None:
                # Plot stop token prediction
                prob = output_dict["prob"].cpu().numpy()

                fig = plt.Figure()
                ax = fig.add_subplot(1, 1, 1)
                ax.plot(prob)
                ax.set_title(f"{key}")
                ax.set_xlabel("Output")
                ax.set_ylabel("Stop probability")
                ax.set_ylim(0, 1)
                ax.grid(which="both")

                fig.set_tight_layout(True)
                fig.savefig(output_dir / f"probs/{key}.png")
                fig.clf()

            if output_dict.get("wav") is not None:
                # TODO(kamo): Write scp
                sf.write(
                    f"{output_dir}/wav/{key}.wav",
                    output_dict["wav"].cpu().numpy(),
                    text2speech.fs,
                    "PCM_16",
                )

    # remove the output directories whose contents are not in the output dict
    if output_dict.get("feat_gen") is None:
        shutil.rmtree(output_dir / "norm")
    if output_dict.get("feat_gen_denorm") is None:
        shutil.rmtree(output_dir / "denorm")
    if output_dict.get("att_w") is None:
        shutil.rmtree(output_dir / "att_ws")
    if output_dict.get("duration") is None:
        shutil.rmtree(output_dir / "durations")
    if output_dict.get("focus_rate") is None:
        shutil.rmtree(output_dir / "focus_rates")
    if output_dict.get("prob") is None:
        shutil.rmtree(output_dir / "probs")
    if output_dict.get("wav") is None:
        shutil.rmtree(output_dir / "wav")

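# For reference, inference() writes at most the following outputs under
# `output_dir`; as the cleanup above shows, directories for outputs the model
# does not produce are removed at the end of decoding:
#
#     norm/feats.scp             generated features (normalized)
#     denorm/feats.scp           generated features (denormalized)
#     speech_shape/speech_shape  shapes of the generated features
#     wav/<key>.wav              generated waveforms
#     att_ws/<key>.png           attention weight plots
#     probs/<key>.png            stop-token probability plots
#     durations/durations        predicted durations
#     focus_rates/focus_rates    focus rates
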
def get_parser():
    """Get argument parser."""
    parser = config_argparse.ArgumentParser(
        description="TTS inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use "_" instead of "-" as separator.
    # "-" is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The path of output directory",
    )
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed",
    )
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    group.add_argument(
        "--key_file",
        type=str_or_none,
    )
    group.add_argument(
        "--allow_variable_data_keys",
        type=str2bool,
        default=False,
    )

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--train_config",
        type=str,
        help="Training configuration file",
    )
    group.add_argument(
        "--model_file",
        type=str,
        help="Model parameter file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If this option is specified, train_config and "
        "model_file will be overwritten",
    )

    group = parser.add_argument_group("Decoding related")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=10.0,
        help="Maximum length ratio in decoding",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Minimum length ratio in decoding",
    )
    group.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Threshold value in decoding",
    )
    group.add_argument(
        "--use_att_constraint",
        type=str2bool,
        default=False,
        help="Whether to use attention constraint",
    )
    group.add_argument(
        "--backward_window",
        type=int,
        default=1,
        help="Backward window value in attention constraint",
    )
    group.add_argument(
        "--forward_window",
        type=int,
        default=3,
        help="Forward window value in attention constraint",
    )
    group.add_argument(
        "--use_teacher_forcing",
        type=str2bool,
        default=False,
        help="Whether to use teacher forcing",
    )
    parser.add_argument(
        "--speed_control_alpha",
        type=float,
        default=1.0,
        help="Alpha in FastSpeech to change the speed of generated speech",
    )
    parser.add_argument(
        "--noise_scale",
        type=float,
        default=0.667,
        help="Noise scale parameter for the flow in vits",
    )
    parser.add_argument(
        "--noise_scale_dur",
        type=float,
        default=0.8,
        help="Noise scale parameter for the stochastic duration predictor in vits",
    )
    group.add_argument(
        "--always_fix_seed",
        type=str2bool,
        default=False,
        help="Whether to always fix seed",
    )

    group = parser.add_argument_group("Vocoder related")
    group.add_argument(
        "--vocoder_config",
        type=str_or_none,
        help="Vocoder configuration file",
    )
    group.add_argument(
        "--vocoder_file",
        type=str_or_none,
        help="Vocoder parameter file",
    )
    group.add_argument(
        "--vocoder_tag",
        type=str,
        help="Pretrained vocoder tag. If this option is specified, vocoder_config "
        "and vocoder_file will be overwritten",
    )
    return parser

def main(cmd=None):
    """Run TTS model inference."""
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    inference(**kwargs)


if __name__ == "__main__":
    main()
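
# Example invocation (illustrative; the paths below are placeholders, and the
# scp triple follows the ESPnet "path,name,type" convention used by
# --data_path_and_name_and_type):
#
#     python -m espnet2.bin.tts_inference \
#         --output_dir exp/tts_decode \
#         --train_config exp/tts_train/config.yaml \
#         --model_file exp/tts_train/train.loss.best.pth \
#         --data_path_and_name_and_type dump/test/text,text,text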