#!/usr/bin/env python3
"""Script to run inference of a text-to-speech model."""
import argparse
import logging
import shutil
import sys
import time
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import soundfile as sf
import torch
from packaging.version import parse as V
from typeguard import typechecked

from espnet2.fileio.npy_scp import NpyScpWriter
from espnet2.gan_tts.vits import VITS
from espnet2.tasks.tts import TTSTask
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
from espnet2.tts.fastspeech import FastSpeech
from espnet2.tts.fastspeech2 import FastSpeech2
from espnet2.tts.tacotron2 import Tacotron2
from espnet2.tts.transformer import Transformer
from espnet2.tts.utils import DurationCalculator
from espnet2.utils import config_argparse
from espnet2.utils.types import str2bool, str2triple_str, str_or_none
from espnet.utils.cli_utils import get_commandline_args


class Text2Speech:
    """Text2Speech class.

    Examples:
        >>> from espnet2.bin.tts_inference import Text2Speech
        >>> # Case 1: Load the local model and use Griffin-Lim vocoder
        >>> text2speech = Text2Speech(
        >>>     train_config="/path/to/config.yml",
        >>>     model_file="/path/to/model.pth",
        >>> )
        >>> # Case 2: Load the local model and the pretrained vocoder
        >>> text2speech = Text2Speech.from_pretrained(
        >>>     train_config="/path/to/config.yml",
        >>>     model_file="/path/to/model.pth",
        >>>     vocoder_tag="parallel_wavegan/ljspeech_parallel_wavegan.v1",
        >>> )
        >>> # Case 3: Load the pretrained model and use Griffin-Lim vocoder
        >>> text2speech = Text2Speech.from_pretrained(
        >>>     model_tag="kan-bayashi/ljspeech_tacotron2",
        >>> )
        >>> # Case 4: Load the pretrained model and the pretrained vocoder
        >>> text2speech = Text2Speech.from_pretrained(
        >>>     model_tag="kan-bayashi/ljspeech_tacotron2",
        >>>     vocoder_tag="parallel_wavegan/ljspeech_parallel_wavegan.v1",
        >>> )
        >>> # Run inference and save as wav file
        >>> import soundfile as sf
        >>> wav = text2speech("Hello, World")["wav"]
        >>> sf.write("out.wav", wav.numpy(), text2speech.fs, "PCM_16")

    """

    @typechecked
    def __init__(
        self,
        train_config: Union[Path, str, None] = None,
        model_file: Union[Path, str, None] = None,
        threshold: float = 0.5,
        minlenratio: float = 0.0,
        maxlenratio: float = 10.0,
        use_teacher_forcing: bool = False,
        use_att_constraint: bool = False,
        backward_window: int = 1,
        forward_window: int = 3,
        speed_control_alpha: float = 1.0,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        vocoder_config: Union[Path, str, None] = None,
        vocoder_file: Union[Path, str, None] = None,
        dtype: str = "float32",
        device: str = "cpu",
        seed: int = 777,
        always_fix_seed: bool = False,
        prefer_normalized_feats: bool = False,
    ):
        """Initialize Text2Speech module."""
        # setup model
        model, train_args = TTSTask.build_model_from_file(
            train_config, model_file, device
        )
        model.to(dtype=getattr(torch, dtype)).eval()
        self.device = device
        self.dtype = dtype
        self.train_args = train_args
        self.model = model
        self.tts = model.tts
        self.normalize = model.normalize
        self.feats_extract = model.feats_extract
        self.duration_calculator = DurationCalculator()
        self.preprocess_fn = TTSTask.build_preprocess_fn(train_args, False)
        self.use_teacher_forcing = use_teacher_forcing
        self.seed = seed
        self.always_fix_seed = always_fix_seed
        self.vocoder = None
        self.prefer_normalized_feats = prefer_normalized_feats
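        # Some models (e.g. end-to-end text2wav models such as VITS) generate
        # waveform directly and report require_vocoder = False, so no external
        # vocoder is built for them.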
        if self.tts.require_vocoder:
            vocoder = TTSTask.build_vocoder_from_file(
                vocoder_config, vocoder_file, model, device
            )
            if isinstance(vocoder, torch.nn.Module):
                vocoder.to(dtype=getattr(torch, dtype)).eval()
            self.vocoder = vocoder
        logging.info(f"Extractor:\n{self.feats_extract}")
        logging.info(f"Normalizer:\n{self.normalize}")
        logging.info(f"TTS:\n{self.tts}")
        if self.vocoder is not None:
            logging.info(f"Vocoder:\n{self.vocoder}")

        # setup decoding config
        decode_conf = {}
        decode_conf.update(use_teacher_forcing=use_teacher_forcing)
        if isinstance(self.tts, (Tacotron2, Transformer)):
            decode_conf.update(
                threshold=threshold,
                maxlenratio=maxlenratio,
                minlenratio=minlenratio,
            )
        if isinstance(self.tts, Tacotron2):
            decode_conf.update(
                use_att_constraint=use_att_constraint,
                forward_window=forward_window,
                backward_window=backward_window,
            )
        if isinstance(self.tts, (FastSpeech, FastSpeech2, VITS)):
            decode_conf.update(alpha=speed_control_alpha)
        if isinstance(self.tts, VITS):
            decode_conf.update(
                noise_scale=noise_scale,
                noise_scale_dur=noise_scale_dur,
            )
        self.decode_conf = decode_conf

    @typechecked
    def __call__(
        self,
        text: Union[str, torch.Tensor, np.ndarray],
        speech: Union[torch.Tensor, np.ndarray, None] = None,
        durations: Union[torch.Tensor, np.ndarray, None] = None,
        spembs: Union[torch.Tensor, np.ndarray, None] = None,
        sids: Union[torch.Tensor, np.ndarray, None] = None,
        lids: Union[torch.Tensor, np.ndarray, None] = None,
        decode_conf: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run text-to-speech."""
        # check inputs
        if self.use_speech and speech is None:
            raise RuntimeError("Missing required argument: 'speech'")
        if self.use_sids and sids is None:
            raise RuntimeError("Missing required argument: 'sids'")
        if self.use_lids and lids is None:
            raise RuntimeError("Missing required argument: 'lids'")
        if self.use_spembs and spembs is None:
            raise RuntimeError("Missing required argument: 'spembs'")

        # prepare batch
        if isinstance(text, str):
            text = self.preprocess_fn("<dummy>", dict(text=text))["text"]
        batch = dict(text=text)
        if speech is not None:
            batch.update(speech=speech)
        if durations is not None:
            batch.update(durations=durations)
        if spembs is not None:
            batch.update(spembs=spembs)
        if sids is not None:
            batch.update(sids=sids)
        if lids is not None:
            batch.update(lids=lids)
        batch = to_device(batch, self.device)

        # overwrite the decode configs if provided
        cfg = self.decode_conf
        if decode_conf is not None:
            cfg = self.decode_conf.copy()
            cfg.update(decode_conf)

        # inference
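        # Re-seeding on every call keeps stochastic generation (e.g. the noise
        # sampled by VITS) reproducible across repeated calls.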
        if self.always_fix_seed:
            set_all_random_seed(self.seed)
        output_dict = self.model.inference(**batch, **cfg)

        # calculate additional metrics
        if output_dict.get("att_w") is not None:
            duration, focus_rate = self.duration_calculator(output_dict["att_w"])
            output_dict.update(duration=duration, focus_rate=focus_rate)

        # apply vocoder (mel-to-wav)
        if self.vocoder is not None:
            if (
                self.prefer_normalized_feats
                or output_dict.get("feat_gen_denorm") is None
            ):
                input_feat = output_dict["feat_gen"]
            else:
                input_feat = output_dict["feat_gen_denorm"]
            wav = self.vocoder(input_feat)
            output_dict.update(wav=wav)
        return output_dict

    @property
    def fs(self) -> Optional[int]:
        """Return sampling rate."""
        if hasattr(self.vocoder, "fs"):
            return self.vocoder.fs
        elif hasattr(self.tts, "fs"):
            return self.tts.fs
        else:
            return None

    @property
    def use_speech(self) -> bool:
        """Return whether speech is needed in inference."""
        return self.use_teacher_forcing or getattr(self.tts, "use_gst", False)

    @property
    def use_sids(self) -> bool:
        """Return whether sid is needed in inference."""
        return self.tts.spks is not None

    @property
    def use_lids(self) -> bool:
        """Return whether lid is needed in inference."""
        return self.tts.langs is not None

    @property
    def use_spembs(self) -> bool:
        """Return whether spemb is needed in inference."""
        return self.tts.spk_embed_dim is not None

    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        vocoder_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ):
        """Build Text2Speech instance from the pretrained model.

        Args:
            model_tag (Optional[str]): Model tag of the pretrained models.
                Currently, the tags of espnet_model_zoo are supported.
            vocoder_tag (Optional[str]): Vocoder tag of the pretrained vocoders.
                Currently, the tags of parallel_wavegan are supported, which should
                start with the prefix "parallel_wavegan/".

        Returns:
            Text2Speech: Text2Speech instance.

        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader
            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        if vocoder_tag is not None:
            if vocoder_tag.startswith("parallel_wavegan/"):
                try:
                    from parallel_wavegan.utils import download_pretrained_model
                except ImportError:
                    logging.error(
                        "`parallel_wavegan` is not installed. "
                        "Please install via `pip install -U parallel_wavegan`."
                    )
                    raise
                from parallel_wavegan import __version__

                # NOTE(kan-bayashi): Filelock download is supported from 0.5.2
                assert V(__version__) > V("0.5.1"), (
                    "Please install the latest parallel_wavegan "
                    "via `pip install -U parallel_wavegan`."
                )
                vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "")
                vocoder_file = download_pretrained_model(vocoder_tag)
                vocoder_config = Path(vocoder_file).parent / "config.yml"
                kwargs.update(vocoder_config=vocoder_config, vocoder_file=vocoder_file)
            else:
                raise ValueError(f"{vocoder_tag} is an unsupported format.")

        return Text2Speech(**kwargs)


@typechecked
def inference(
    output_dir: Union[Path, str],
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    model_tag: Optional[str],
    threshold: float,
    minlenratio: float,
    maxlenratio: float,
    use_teacher_forcing: bool,
    use_att_constraint: bool,
    backward_window: int,
    forward_window: int,
    speed_control_alpha: float,
    noise_scale: float,
    noise_scale_dur: float,
    always_fix_seed: bool,
    allow_variable_data_keys: bool,
    vocoder_config: Optional[str],
    vocoder_file: Optional[str],
    vocoder_tag: Optional[str],
):
"""Run text-to-speech inference.""" | |
if batch_size > 1: | |
raise NotImplementedError("batch decoding is not implemented") | |
if ngpu > 1: | |
raise NotImplementedError("only single GPU decoding is supported") | |
logging.basicConfig( | |
level=log_level, | |
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", | |
) | |
if ngpu >= 1: | |
device = "cuda" | |
else: | |
device = "cpu" | |
# 1. Set random-seed | |
set_all_random_seed(seed) | |
# 2. Build model | |
text2speech_kwargs = dict( | |
train_config=train_config, | |
model_file=model_file, | |
threshold=threshold, | |
maxlenratio=maxlenratio, | |
minlenratio=minlenratio, | |
use_teacher_forcing=use_teacher_forcing, | |
use_att_constraint=use_att_constraint, | |
backward_window=backward_window, | |
forward_window=forward_window, | |
speed_control_alpha=speed_control_alpha, | |
noise_scale=noise_scale, | |
noise_scale_dur=noise_scale_dur, | |
vocoder_config=vocoder_config, | |
vocoder_file=vocoder_file, | |
dtype=dtype, | |
device=device, | |
seed=seed, | |
always_fix_seed=always_fix_seed, | |
) | |
text2speech = Text2Speech.from_pretrained( | |
model_tag=model_tag, | |
vocoder_tag=vocoder_tag, | |
**text2speech_kwargs, | |
) | |
# 3. Build data-iterator | |
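    # Speech is only needed as an extra input for teacher forcing or GST
    # (see Text2Speech.use_speech); otherwise it is dropped from the data specs.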
    if not text2speech.use_speech:
        data_path_and_name_and_type = list(
            filter(lambda x: x[1] != "speech", data_path_and_name_and_type)
        )
    loader = TTSTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=TTSTask.build_preprocess_fn(text2speech.train_args, False),
        collate_fn=TTSTask.build_collate_fn(text2speech.train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Start for-loop
    output_dir = Path(output_dir)
    (output_dir / "norm").mkdir(parents=True, exist_ok=True)
    (output_dir / "denorm").mkdir(parents=True, exist_ok=True)
    (output_dir / "speech_shape").mkdir(parents=True, exist_ok=True)
    (output_dir / "wav").mkdir(parents=True, exist_ok=True)
    (output_dir / "att_ws").mkdir(parents=True, exist_ok=True)
    (output_dir / "probs").mkdir(parents=True, exist_ok=True)
    (output_dir / "durations").mkdir(parents=True, exist_ok=True)
    (output_dir / "focus_rates").mkdir(parents=True, exist_ok=True)

    # Lazy load to avoid the backend error
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator
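
    # Writers for normalized/denormalized features and the plain-text stats
    # files; directories for outputs the model never produced are removed
    # after the loop.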
    with NpyScpWriter(
        output_dir / "norm",
        output_dir / "norm/feats.scp",
    ) as norm_writer, NpyScpWriter(
        output_dir / "denorm", output_dir / "denorm/feats.scp"
    ) as denorm_writer, open(
        output_dir / "speech_shape/speech_shape", "w"
    ) as shape_writer, open(
        output_dir / "durations/durations", "w"
    ) as duration_writer, open(
        output_dir / "focus_rates/focus_rates", "w"
    ) as focus_rate_writer:
        for idx, (keys, batch) in enumerate(loader, 1):
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert _bs == 1, _bs

            # Change to single sequence and remove *_length
            # because inference() requires 1-seq, not mini-batch.
            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            start_time = time.perf_counter()
            output_dict = text2speech(**batch)

            key = keys[0]
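            # NOTE: the +1 below presumably accounts for one extra token
            # (e.g. an appended <eos>) added to the input during inference.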
            insize = next(iter(batch.values())).size(0) + 1
            if output_dict.get("feat_gen") is not None:
                # standard text2mel model case
                feat_gen = output_dict["feat_gen"]
                logging.info(
                    "inference speed = {:.1f} frames / sec.".format(
                        int(feat_gen.size(0)) / (time.perf_counter() - start_time)
                    )
                )
                logging.info(f"{key} (size:{insize}->{feat_gen.size(0)})")
                if feat_gen.size(0) == insize * maxlenratio:
                    logging.warning(f"output length reaches maximum length ({key}).")

                norm_writer[key] = output_dict["feat_gen"].cpu().numpy()
                shape_writer.write(
                    f"{key} " + ",".join(map(str, output_dict["feat_gen"].shape)) + "\n"
                )
                if output_dict.get("feat_gen_denorm") is not None:
                    denorm_writer[key] = output_dict["feat_gen_denorm"].cpu().numpy()
            else:
                # end-to-end text2wav model case
                wav = output_dict["wav"]
                logging.info(
                    "inference speed = {:.1f} points / sec.".format(
                        int(wav.size(0)) / (time.perf_counter() - start_time)
                    )
                )
                logging.info(f"{key} (size:{insize}->{wav.size(0)})")

            if output_dict.get("duration") is not None:
                # Save duration and focus rates
                duration_writer.write(
                    f"{key} "
                    + " ".join(map(str, output_dict["duration"].long().cpu().numpy()))
                    + "\n"
                )
            if output_dict.get("focus_rate") is not None:
                focus_rate_writer.write(
                    f"{key} {float(output_dict['focus_rate']):.5f}\n"
                )

            if output_dict.get("att_w") is not None:
                # Plot attention weight
                att_w = output_dict["att_w"].cpu().numpy()
                if att_w.ndim == 2:
                    att_w = att_w[None][None]
                elif att_w.ndim != 4:
                    raise RuntimeError(f"Must be 2 or 4 dimension: {att_w.ndim}")
                w, h = plt.figaspect(att_w.shape[0] / att_w.shape[1])
                fig = plt.Figure(
                    figsize=(
                        w * 1.3 * min(att_w.shape[0], 2.5),
                        h * 1.3 * min(att_w.shape[1], 2.5),
                    )
                )
                fig.suptitle(f"{key}")
                axes = fig.subplots(att_w.shape[0], att_w.shape[1])
                if len(att_w) == 1:
                    axes = [[axes]]
                for ax, att_w in zip(axes, att_w):
                    for ax_, att_w_ in zip(ax, att_w):
                        ax_.imshow(att_w_.astype(np.float32), aspect="auto")
                        ax_.set_xlabel("Input")
                        ax_.set_ylabel("Output")
                        ax_.xaxis.set_major_locator(MaxNLocator(integer=True))
                        ax_.yaxis.set_major_locator(MaxNLocator(integer=True))
                fig.set_tight_layout({"rect": [0, 0.03, 1, 0.95]})
                fig.savefig(output_dir / f"att_ws/{key}.png")
                fig.clf()

            if output_dict.get("prob") is not None:
                # Plot stop token prediction
                prob = output_dict["prob"].cpu().numpy()
                fig = plt.Figure()
                ax = fig.add_subplot(1, 1, 1)
                ax.plot(prob)
                ax.set_title(f"{key}")
                ax.set_xlabel("Output")
                ax.set_ylabel("Stop probability")
                ax.set_ylim(0, 1)
                ax.grid(which="both")
                fig.set_tight_layout(True)
                fig.savefig(output_dir / f"probs/{key}.png")
                fig.clf()

            if output_dict.get("wav") is not None:
                # TODO(kamo): Write scp
                sf.write(
                    f"{output_dir}/wav/{key}.wav",
                    output_dict["wav"].cpu().numpy(),
                    text2speech.fs,
                    "PCM_16",
                )

    # remove output directories for entries not included in the output dict
    if output_dict.get("feat_gen") is None:
        shutil.rmtree(output_dir / "norm")
    if output_dict.get("feat_gen_denorm") is None:
        shutil.rmtree(output_dir / "denorm")
    if output_dict.get("att_w") is None:
        shutil.rmtree(output_dir / "att_ws")
    if output_dict.get("duration") is None:
        shutil.rmtree(output_dir / "durations")
    if output_dict.get("focus_rate") is None:
        shutil.rmtree(output_dir / "focus_rates")
    if output_dict.get("prob") is None:
        shutil.rmtree(output_dir / "probs")
    if output_dict.get("wav") is None:
        shutil.rmtree(output_dir / "wav")


def get_parser():
    """Get argument parser."""
    parser = config_argparse.ArgumentParser(
        description="TTS inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use "_" instead of "-" as separator.
    # "-" is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The path of output directory",
    )
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed",
    )
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    group.add_argument(
        "--key_file",
        type=str_or_none,
    )
    group.add_argument(
        "--allow_variable_data_keys",
        type=str2bool,
        default=False,
    )

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--train_config",
        type=str,
        help="Training configuration file",
    )
    group.add_argument(
        "--model_file",
        type=str,
        help="Model parameter file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If this option is specified, train_config "
        "and model_file will be overwritten",
    )

    group = parser.add_argument_group("Decoding related")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=10.0,
        help="Maximum length ratio in decoding",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Minimum length ratio in decoding",
    )
    group.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Threshold value in decoding",
    )
    group.add_argument(
        "--use_att_constraint",
        type=str2bool,
        default=False,
        help="Whether to use attention constraint",
    )
    group.add_argument(
        "--backward_window",
        type=int,
        default=1,
        help="Backward window value in attention constraint",
    )
    group.add_argument(
        "--forward_window",
        type=int,
        default=3,
        help="Forward window value in attention constraint",
    )
    group.add_argument(
        "--use_teacher_forcing",
        type=str2bool,
        default=False,
        help="Whether to use teacher forcing",
    )
    group.add_argument(
        "--speed_control_alpha",
        type=float,
        default=1.0,
        help="Alpha in FastSpeech to change the speed of generated speech",
    )
    group.add_argument(
        "--noise_scale",
        type=float,
        default=0.667,
        help="Noise scale parameter for the flow in VITS",
    )
    group.add_argument(
        "--noise_scale_dur",
        type=float,
        default=0.8,
        help="Noise scale parameter for the stochastic duration predictor in VITS",
    )
    group.add_argument(
        "--always_fix_seed",
        type=str2bool,
        default=False,
        help="Whether to always fix seed",
    )

    group = parser.add_argument_group("Vocoder related")
    group.add_argument(
        "--vocoder_config",
        type=str_or_none,
        help="Vocoder configuration file",
    )
    group.add_argument(
        "--vocoder_file",
        type=str_or_none,
        help="Vocoder parameter file",
    )
    group.add_argument(
        "--vocoder_tag",
        type=str,
        help="Pretrained vocoder tag. If this option is specified, vocoder_config "
        "and vocoder_file will be overwritten",
    )
    return parser


def main(cmd=None):
    """Run TTS model inference."""
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    inference(**kwargs)


if __name__ == "__main__":
    main()
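

# Example invocation (a sketch; the dump/exp paths are illustrative):
#
#   python -m espnet2.bin.tts_inference \
#       --output_dir exp/tts/decode \
#       --data_path_and_name_and_type dump/test/text,text,text \
#       --train_config exp/tts/config.yaml \
#       --model_file exp/tts/train.loss.best.pth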