tango / audioldm /__main__.py
deepanway's picture
add required files
6b448ad
#!/usr/bin/python3
import os
from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration
import argparse
# Cache directory path: the AUDIOLDM_CACHE_DIR environment variable when it is
# set, otherwise ".cache/audioldm" below the current user's home directory.
_default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache/audioldm")
CACHE_DIR = os.environ.get("AUDIOLDM_CACHE_DIR", _default_cache_dir)
# Command-line options of this entry point.  Each entry pairs the option
# strings with the keyword arguments passed to ``add_argument``.  Options are
# registered in the order listed, which is also the order shown by ``--help``.
_CLI_OPTIONS = (
    (
        ("--mode",),
        dict(
            type=str,
            default="generation",
            help="generation: text-to-audio generation; transfer: style transfer",
            choices=["generation", "transfer"],
        ),
    ),
    (
        ("-t", "--text"),
        dict(
            type=str,
            default="",
            help="Text prompt to the model for audio generation",
        ),
    ),
    (
        ("-f", "--file_path"),
        dict(
            type=str,
            default=None,
            help="(--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating simialr audio",
        ),
    ),
    (
        ("--transfer_strength",),
        dict(
            type=float,
            default=0.5,
            help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text",
        ),
    ),
    (
        ("-s", "--save_path"),
        dict(
            type=str,
            help="The path to save model output",
            default="./output",
        ),
    ),
    (
        ("--model_name",),
        dict(
            type=str,
            help="The checkpoint you gonna use",
            default="audioldm-s-full",
            choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2"],
        ),
    ),
    (
        ("-ckpt", "--ckpt_path"),
        dict(
            type=str,
            help="The path to the pretrained .ckpt model",
            default=None,
        ),
    ),
    (
        ("-b", "--batchsize"),
        dict(
            type=int,
            default=1,
            help="Generate how many samples at the same time",
        ),
    ),
    (
        ("--ddim_steps",),
        dict(
            type=int,
            default=200,
            help="The sampling step for DDIM",
        ),
    ),
    (
        ("-gs", "--guidance_scale"),
        dict(
            type=float,
            default=2.5,
            help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
        ),
    ),
    (
        ("-dur", "--duration"),
        dict(
            type=float,
            default=10.0,
            help="The duration of the samples",
        ),
    ),
    (
        ("-n", "--n_candidate_gen_per_text"),
        dict(
            type=int,
            default=3,
            help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
        ),
    ),
    (
        ("--seed",),
        dict(
            type=int,
            default=42,
            help="Change this value (any integer number) will lead to a different generation result.",
        ),
    ),
)

parser = argparse.ArgumentParser()
for _flags, _settings in _CLI_OPTIONS:
    # None of the options is mandatory; every one carries a default.
    parser.add_argument(*_flags, required=False, **_settings)
# Script body: parse the command line, validate it, work out the output
# directory, build the model, run either text-to-audio generation or style
# transfer, and save the resulting waveform.
args = parser.parse_args()

if args.ckpt_path is not None:
    print("Warning: ckpt_path has no effect after version 0.0.20.")

# Input validation goes through parser.error() instead of `assert`:
# assertions are removed under `python -O`, which would let invalid input
# through.  parser.error() prints the usage line plus the message and exits
# with status 2.
if args.duration % 2.5 != 0:
    parser.error("Duration must be a multiple of 2.5")

# Style transfer needs an existing source audio file.  This is checked before
# any directory is created or the model is built, so a bad path fails at once.
if args.mode == "transfer":
    if args.file_path is None:
        parser.error("--file_path is required when --mode is transfer")
    if not os.path.exists(args.file_path):
        parser.error(
            "The original audio file \'%s\' for style transfer does not exist."
            % args.file_path
        )

# `mode` only names the output sub-directory; dispatch below still uses
# args.mode, so audio-guided generation goes through text_to_audio as well.
mode = args.mode
if mode == "generation" and args.file_path is not None:
    mode = "generation_audio_to_audio"
    if args.text:
        print("Warning: You have specified the --file_path. --text will be ignored")
        args.text = ""

save_path = os.path.join(args.save_path, mode)
if args.file_path is not None:
    # Take the file name first and only then drop its extension.  Splitting
    # the whole path on "." breaks on paths such as "./clip.wav" or
    # "my.dir/clip.wav", which yield an empty or wrong folder name.
    input_stem = os.path.splitext(os.path.basename(args.file_path))[0]
    save_path = os.path.join(save_path, input_stem)

text = args.text
random_seed = args.seed
duration = args.duration
guidance_scale = args.guidance_scale
n_candidate_gen_per_text = args.n_candidate_gen_per_text

os.makedirs(save_path, exist_ok=True)
audioldm = build_model(model_name=args.model_name)

if args.mode == "generation":
    waveform = text_to_audio(
        audioldm,
        text,
        args.file_path,
        random_seed,
        duration=duration,
        guidance_scale=guidance_scale,
        ddim_steps=args.ddim_steps,
        n_candidate_gen_per_text=n_candidate_gen_per_text,
        batchsize=args.batchsize,
    )
elif args.mode == "transfer":
    waveform = style_transfer(
        audioldm,
        text,
        args.file_path,
        args.transfer_strength,
        random_seed,
        duration=duration,
        guidance_scale=guidance_scale,
        ddim_steps=args.ddim_steps,
        batchsize=args.batchsize,
    )
    # Insert a singleton middle axis into the style-transfer output.
    # NOTE(review): presumably style_transfer returns (batch, samples) while
    # save_wave expects (batch, 1, samples) as produced by text_to_audio;
    # confirm that this reshaping belongs to the transfer branch only.
    waveform = waveform[:, None, :]

# The output name combines the current time stamp and the (possibly empty) prompt.
save_wave(waveform, save_path, name="%s_%s" % (get_time(), text))