#!/usr/bin/python3
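# Command-line entry point for AudioLDM: text-to-audio generation and audio
# style transfer. Example invocation (command name and prompt are illustrative):
#   audioldm -t "A hammer is hitting a wooden surface" --model_name audioldm-s-full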
import os
from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration
import argparse
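
# Cache directory for downloaded AudioLDM checkpoints; override with the
# AUDIOLDM_CACHE_DIR environment variable.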
CACHE_DIR = os.getenv(
    "AUDIOLDM_CACHE_DIR",
    os.path.join(os.path.expanduser("~"), ".cache/audioldm"))
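
# Command-line arguments: mode selection, prompt / reference audio, and
# sampling hyperparameters.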
parser = argparse.ArgumentParser()
parser.add_argument(
    "--mode",
    type=str,
    required=False,
    default="generation",
    help="generation: text-to-audio generation; transfer: style transfer",
    choices=["generation", "transfer"],
)
parser.add_argument(
    "-t",
    "--text",
    type=str,
    required=False,
    default="",
    help="Text prompt to the model for audio generation",
)
parser.add_argument(
    "-f",
    "--file_path",
    type=str,
    required=False,
    default=None,
    help="(--mode transfer): the original audio file for style transfer; (--mode generation): the guidance audio file for generating similar audio",
)
parser.add_argument(
    "--transfer_strength",
    type=float,
    required=False,
    default=0.5,
    help="A value between 0 and 1: 0 keeps the original audio untouched, 1 transfers it completely to the audio described by the text",
)
parser.add_argument(
    "-s",
    "--save_path",
    type=str,
    required=False,
    help="The path to save model output",
    default="./output",
)
parser.add_argument(
    "--model_name",
    type=str,
    required=False,
    help="The checkpoint to use",
    default="audioldm-s-full",
    choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2"],
)
parser.add_argument(
    "-ckpt",
    "--ckpt_path",
    type=str,
    required=False,
    help="The path to the pretrained .ckpt model",
    default=None,
)
parser.add_argument(
    "-b",
    "--batchsize",
    type=int,
    required=False,
    default=1,
    help="How many samples to generate at the same time",
)
parser.add_argument(
    "--ddim_steps",
    type=int,
    required=False,
    default=200,
    help="The number of DDIM sampling steps",
)
parser.add_argument(
    "-gs",
    "--guidance_scale",
    type=float,
    required=False,
    default=2.5,
    help="Guidance scale (larger => better quality and relevance to the text; smaller => better diversity)",
)
parser.add_argument(
    "-dur",
    "--duration",
    type=float,
    required=False,
    default=10.0,
    help="The duration of the generated samples, in seconds",
)
parser.add_argument(
    "-n",
    "--n_candidate_gen_per_text",
    type=int,
    required=False,
    default=3,
    help="Automatic quality control: the number of candidates to generate per text (e.g., generate three audios and keep the best one). A larger value usually leads to better quality at the cost of heavier computation",
)
parser.add_argument(
    "--seed",
    type=int,
    required=False,
    default=42,
    help="Changing this value (any integer) will lead to a different generation result",
)
args = parser.parse_args()
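
# ckpt_path is retained only for backward compatibility, and the requested
# duration must be a multiple of 2.5.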
if args.ckpt_path is not None:
    print("Warning: ckpt_path has no effect after version 0.0.20.")

assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5"
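
# Supplying a reference audio file in generation mode switches to
# audio-to-audio generation; any text prompt is then ignored.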
mode = args.mode
if(mode == "generation" and args.file_path is not None):
mode = "generation_audio_to_audio"
if(len(args.text) > 0):
print("Warning: You have specified the --file_path. --text will be ignored")
args.text = ""
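
# Group outputs by mode and, when a reference file is given, by its base name.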
save_path = os.path.join(args.save_path, mode)
if args.file_path is not None:
    save_path = os.path.join(save_path, os.path.splitext(os.path.basename(args.file_path))[0])
text = args.text
random_seed = args.seed
duration = args.duration
guidance_scale = args.guidance_scale
n_candidate_gen_per_text = args.n_candidate_gen_per_text
os.makedirs(save_path, exist_ok=True)
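
# Instantiate the pretrained latent diffusion model selected by --model_name.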
audioldm = build_model(model_name=args.model_name)
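
# Generation mode: synthesize audio from the text prompt, optionally guided by
# the reference audio file.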
if args.mode == "generation":
    waveform = text_to_audio(
        audioldm,
        text,
        args.file_path,
        random_seed,
        duration=duration,
        guidance_scale=guidance_scale,
        ddim_steps=args.ddim_steps,
        n_candidate_gen_per_text=n_candidate_gen_per_text,
        batchsize=args.batchsize,
    )
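# Transfer mode: re-render the reference audio in the style described by the
# text prompt.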
elif args.mode == "transfer":
    assert args.file_path is not None
    assert os.path.exists(args.file_path), (
        "The original audio file '%s' for style transfer does not exist." % args.file_path
    )
    waveform = style_transfer(
        audioldm,
        text,
        args.file_path,
        args.transfer_strength,
        random_seed,
        duration=duration,
        guidance_scale=guidance_scale,
        ddim_steps=args.ddim_steps,
        batchsize=args.batchsize,
    )
    # Add a channel axis to the transferred waveform before saving.
    waveform = waveform[:, None, :]
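
# Write the waveform(s) to disk, naming output files with a timestamp and the
# text prompt.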
save_wave(waveform, save_path, name="%s_%s" % (get_time(), text))