Spaces:
Sleeping
Sleeping
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
from argparse import ArgumentParser | |
import os | |
from models.tts.fastspeech2.fs2_inference import FastSpeech2Inference | |
from models.tts.vits.vits_inference import VitsInference | |
from models.tts.valle.valle_inference import VALLEInference | |
from utils.util import load_config | |
import torch | |
def build_inference(args, cfg):
    """Instantiate the inference engine selected by ``cfg.model_type``.

    Args:
        args: Parsed command-line namespace, passed unchanged to the
            inference class constructor.
        cfg: Loaded configuration. Only ``cfg.model_type`` is read here; the
            whole object is passed on to the inference class constructor.

    Returns:
        An instance of the inference class registered for ``cfg.model_type``.

    Raises:
        KeyError: If ``cfg.model_type`` is not a supported name. The message
            lists the supported names. It stays a ``KeyError`` so existing
            callers that catch it keep working.
    """
    supported_inference = {
        "FastSpeech2": FastSpeech2Inference,
        "VITS": VitsInference,
        "VALLE": VALLEInference,
    }
    model_type = cfg.model_type
    # Keep the try body to the lookup only, so a KeyError raised inside an
    # inference constructor is not misreported as an unsupported model type.
    try:
        inference_class = supported_inference[model_type]
    except KeyError:
        raise KeyError(
            f"Unsupported model_type {model_type!r}; "
            f"expected one of {sorted(supported_inference)}"
        ) from None
    return inference_class(args, cfg)
def cuda_relevant(deterministic=False):
    """Set global PyTorch CUDA/cuDNN backend flags before running inference.

    Releases cached CUDA memory, allows TF32 for matmul and cuDNN, and switches
    between deterministic and autotuned (benchmark) cuDNN behaviour.

    Args:
        deterministic: If True, cuDNN runs in deterministic mode, cuDNN
            ``benchmark`` is disabled, and deterministic algorithms are
            required globally. If False (the default), each of those settings
            is inverted.
    """
    torch.cuda.empty_cache()

    cuda_backend = torch.backends.cuda
    cudnn_backend = torch.backends.cudnn

    # Allow TF32 math (relevant on Ampere and newer GPUs).
    cuda_backend.matmul.allow_tf32 = True
    cudnn_backend.enabled = True
    cudnn_backend.allow_tf32 = True

    # benchmark is always the inverse of deterministic.
    # NOTE(review): per PyTorch's reproducibility notes, deterministic mode
    # may also need the CUBLAS_WORKSPACE_CONFIG env var for some cuBLAS ops;
    # confirm if deterministic=True is ever passed.
    cudnn_backend.deterministic = deterministic
    cudnn_backend.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)
def build_parser():
    """Build the argument parser with the shared TTS inference options.

    Options are registered in the fixed order of ``option_specs`` below, which
    is also the order they appear in ``--help``. Model-specific options are
    not added here.

    Returns:
        argparse.ArgumentParser: parser with the common inference options.
    """
    acoustics_help = (
        "Acoustic model checkpoint directory. If a directory is given, "
        "search for the latest checkpoint dir in the directory. If a specific "
        "checkpoint dir is given, directly load the checkpoint."
    )

    # (flag, keyword arguments for add_argument), in registration order.
    option_specs = (
        ("--config", {"type": str, "required": True,
                      "help": "JSON/YAML file for configurations."}),
        ("--dataset", {"type": str, "default": None,
                       "help": "convert from the source data"}),
        ("--testing_set", {"type": str, "default": "test",
                           "help": "train, test, golden_test"}),
        ("--test_list_file", {"type": str, "default": None,
                              "help": "convert from the test list file"}),
        ("--speaker_name", {"type": str, "default": None,
                            "help": "speaker name for multi-speaker synthesis, "
                                    "for single-sentence mode only"}),
        ("--text", {"type": str, "default": "",
                    "help": "Text to be synthesized."}),
        ("--vocoder_dir", {"type": str, "default": None,
                           "help": "Vocoder checkpoint directory. Searching "
                                   "behavior is the same as the acoustics one."}),
        ("--acoustics_dir", {"type": str, "default": None,
                             "help": acoustics_help}),
        # NOTE(review): --checkpoint_path reuses --acoustics_dir's help text
        # verbatim; confirm whether it should describe something different.
        ("--checkpoint_path", {"type": str, "default": None,
                               "help": acoustics_help}),
        ("--mode", {"type": str, "choices": ["batch", "single"],
                    "required": True,
                    "help": "Synthesize a whole dataset or a single sentence"}),
        ("--log_level", {"type": str, "default": "warning",
                         "help": "Logging level. Default: warning"}),
        ("--pitch_control", {"type": float, "default": 1.0,
                             "help": "control the pitch of the whole utterance, "
                                     "larger value for higher pitch"}),
        ("--energy_control", {"type": float, "default": 1.0,
                              "help": "control the energy of the whole utterance, "
                                      "larger value for larger volume"}),
        ("--duration_control", {"type": float, "default": 1.0,
                                "help": "control the speed of the whole utterance, "
                                        "larger value for slower speaking rate"}),
        ("--output_dir", {"type": str, "default": None,
                          "help": "Output dir for saving generated results"}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
def main():
    """Command-line entry point: parse arguments, load config, run inference."""
    arg_parser = build_parser()
    # VALLE's extra CLI options are registered regardless of model type; the
    # model type is only known once the config has been loaded below.
    VALLEInference.add_arguments(arg_parser)
    cli_args = arg_parser.parse_args()

    config = load_config(cli_args.config)
    cuda_relevant()
    build_inference(cli_args, config).inference()


if __name__ == "__main__":
    main()