Spaces:
Running
Running
File size: 4,437 Bytes
0d80816 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from argparse import ArgumentParser
import os
from models.tts.fastspeech2.fs2_inference import FastSpeech2Inference
from models.tts.vits.vits_inference import VitsInference
from models.tts.valle.valle_inference import VALLEInference
from utils.util import load_config
import torch
def build_inference(args, cfg):
    """Instantiate the inference engine selected by ``cfg.model_type``.

    Args:
        args: Parsed command-line namespace, forwarded unchanged to the
            inference class constructor.
        cfg: Loaded configuration object; its ``model_type`` attribute picks
            which inference class is used.

    Returns:
        The result of ``inference_class(args, cfg)`` for the matching class.

    Raises:
        KeyError: If ``cfg.model_type`` is not one of the supported model
            types. The message names the offending value and lists the
            supported ones (the exception type is unchanged from before so
            existing handlers keep working).
    """
    supported_inference = {
        "FastSpeech2": FastSpeech2Inference,
        "VITS": VitsInference,
        "VALLE": VALLEInference,
    }
    model_type = cfg.model_type
    try:
        inference_class = supported_inference[model_type]
    except KeyError:
        supported = ", ".join(sorted(supported_inference))
        # Replace the bare KeyError('<type>') with an actionable message.
        raise KeyError(
            f"Unsupported model_type {model_type!r} for TTS inference; "
            f"expected one of: {supported}"
        ) from None
    return inference_class(args, cfg)
def cuda_relevant(deterministic=False):
    """Set global PyTorch CUDA/cuDNN backend flags before running inference.

    Releases cached CUDA allocator memory, turns on TF32 for matmul and
    cuDNN, enables cuDNN, and applies the requested determinism mode.

    Args:
        deterministic: If True, request deterministic cuDNN kernels and
            deterministic algorithms, and switch cuDNN benchmarking off.
            If False (default), benchmarking is on and deterministic
            algorithms are not enforced.
    """
    torch.cuda.empty_cache()

    cuda_backend = torch.backends.cuda
    cudnn_backend = torch.backends.cudnn

    # Allow TensorFloat-32 math (relevant on Ampere-class GPUs and newer).
    cuda_backend.matmul.allow_tf32 = True
    cudnn_backend.enabled = True
    cudnn_backend.allow_tf32 = True

    # cuDNN benchmarking is always the inverse of the determinism switch.
    use_benchmark = not deterministic
    cudnn_backend.deterministic = deterministic
    cudnn_backend.benchmark = use_benchmark
    torch.use_deterministic_algorithms(deterministic)
def build_parser():
    """Create the argument parser shared by all TTS inference modes.

    Returns:
        An ``argparse.ArgumentParser`` carrying the generic inference options,
        registered in a fixed order. ``--config`` and ``--mode`` are required.
        Model-specific options are not added here (``main`` adds the VALL-E
        ones separately).
    """
    checkpoint_help = (
        "Acoustic model checkpoint directory. If a directory is given, "
        "search for the latest checkpoint dir in the directory. If a specific "
        "checkpoint dir is given, directly load the checkpoint."
    )

    # (flag, add_argument keyword arguments) in registration order; the order
    # determines how options appear in --help.
    option_specs = (
        ("--config", dict(
            type=str, required=True,
            help="JSON/YAML file for configurations.")),
        ("--dataset", dict(
            type=str, default=None,
            help="convert from the source data")),
        ("--testing_set", dict(
            type=str, default="test",
            help="train, test, golden_test")),
        ("--test_list_file", dict(
            type=str, default=None,
            help="convert from the test list file")),
        ("--speaker_name", dict(
            type=str, default=None,
            help="speaker name for multi-speaker synthesis, for single-sentence mode only")),
        ("--text", dict(
            type=str, default="",
            help="Text to be synthesized.")),
        ("--vocoder_dir", dict(
            type=str, default=None,
            help="Vocoder checkpoint directory. Searching behavior is the same as "
            "the acoustics one.")),
        ("--acoustics_dir", dict(
            type=str, default=None,
            help=checkpoint_help)),
        # NOTE(review): --checkpoint_path reuses the --acoustics_dir help text
        # verbatim, so the two options look identical in --help; confirm the
        # intended meaning and give it its own description.
        ("--checkpoint_path", dict(
            type=str, default=None,
            help=checkpoint_help)),
        ("--mode", dict(
            type=str, choices=["batch", "single"], required=True,
            help="Synthesize a whole dataset or a single sentence")),
        ("--log_level", dict(
            type=str, default="warning",
            help="Logging level. Default: warning")),
        ("--pitch_control", dict(
            type=float, default=1.0,
            help="control the pitch of the whole utterance, larger value for higher pitch")),
        ("--energy_control", dict(
            type=float, default=1.0,
            help="control the energy of the whole utterance, larger value for larger volume")),
        ("--duration_control", dict(
            type=float, default=1.0,
            help="control the speed of the whole utterance, larger value for slower speaking rate")),
        ("--output_dir", dict(
            type=str, default=None,
            help="Output dir for saving generated results")),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser
def main():
    """Command-line entry point: parse arguments, load the config, run inference.

    Builds the shared parser, lets ``VALLEInference`` register extra options
    on it, loads the configuration file named by ``--config``, applies the
    default (non-deterministic) backend settings via ``cuda_relevant()``, and
    finally builds and runs the inference engine chosen by the config.
    """
    arg_parser = build_parser()
    # Let VALLEInference add its own options to the shared parser.
    VALLEInference.add_arguments(arg_parser)
    cli_args = arg_parser.parse_args()

    config = load_config(cli_args.config)

    # Global CUDA/cuDNN flags are set before any inference engine is built.
    cuda_relevant()

    # NOTE(review): cli_args.log_level is parsed but not consumed in this
    # function; confirm whether logging is configured elsewhere.
    build_inference(cli_args, config).inference()
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|