import argparse

import torch

from models.tta.ldm.audioldm_inference import AudioLDMInference
from utils.util import load_config
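
# Example invocation (the script path, config, checkpoint, and vocoder paths
# below are illustrative placeholders; adjust them to your setup):
#
#   python inference.py \
#       --config exp_config.json \
#       --text "A dog barking in the distance" \
#       --checkpoint_path checkpoints/audioldm/model.pt \
#       --vocoder_path checkpoints/vocoder/model.pt \
#       --vocoder_config_path checkpoints/vocoder/config.json \
#       --output_dir results \
#       --num_steps 200 \
#       --guidance_scale 4.0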


def build_inference(args, cfg):
    # Registry mapping model type names (from the config) to inference classes.
    supported_inference = {
        "AudioLDM": AudioLDMInference,
    }

    inference_class = supported_inference[cfg.model_type]
    inference = inference_class(args, cfg)
    return inference


def build_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="JSON/YAML file for configurations.",
    )
    parser.add_argument(
        "--text",
        type=str,
        default="Text to be synthesized.",
        help="Text to be synthesized",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Checkpoint path of the model",
    )
    parser.add_argument(
        "--vocoder_path", type=str, help="Checkpoint path of the vocoder"
    )
    parser.add_argument(
        "--vocoder_config_path", type=str, help="Config path of the vocoder"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output dir for saving generated results",
    )
    parser.add_argument(
        "--num_steps",
        type=int,
        default=200,
        help="The total number of denoising steps",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=4.0,
        help="The scale of classifier-free guidance",
    )
    parser.add_argument("--local_rank", default=-1, type=int)
    return parser


def main():
    # Parse command-line arguments.
    args = build_parser().parse_args()

    # Load the experiment configuration.
    cfg = load_config(args.config)

    # args.local_rank is repurposed here to carry the torch device that the
    # inference class runs on.
    if torch.cuda.is_available():
        args.local_rank = torch.device("cuda")
    else:
        args.local_rank = torch.device("cpu")
    print("args: ", args)

    # Build the inference pipeline and run it.
    inferencer = build_inference(args, cfg)
    inferencer.inference()


if __name__ == "__main__":
    main()