#!/usr/bin/python3
"""Command-line front end for the ``audioldm`` package.

Supports text-to-audio generation, generation guided by an audio file
(``--mode generation`` together with ``--file_path``) and style transfer
(``--mode transfer``).  The result is written below ``--save_path`` via
``save_wave``.
"""
import argparse
import os

from audioldm import (
    build_model,
    get_duration,
    get_time,
    round_up_duration,
    save_wave,
    style_transfer,
    text_to_audio,
)

# Not referenced again in the visible part of this script; kept because other
# code may read it.
CACHE_DIR = os.getenv(
    "AUDIOLDM_CACHE_DIR",
    os.path.join(os.path.expanduser("~"), ".cache/audioldm"),
)

parser = argparse.ArgumentParser()

parser.add_argument(
    "--mode",
    type=str,
    required=False,
    default="generation",
    help="generation: text-to-audio generation; transfer: style transfer",
    choices=["generation", "transfer"],
)

parser.add_argument(
    "-t",
    "--text",
    type=str,
    required=False,
    default="",
    help="Text prompt to the model for audio generation",
)

parser.add_argument(
    "-f",
    "--file_path",
    type=str,
    required=False,
    default=None,
    help="(--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating simialr audio",
)

parser.add_argument(
    "--transfer_strength",
    type=float,
    required=False,
    default=0.5,
    help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text",
)

parser.add_argument(
    "-s",
    "--save_path",
    type=str,
    required=False,
    help="The path to save model output",
    default="./output",
)

parser.add_argument(
    "--model_name",
    type=str,
    required=False,
    help="The checkpoint you gonna use",
    default="audioldm-s-full",
    choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2"],
)

parser.add_argument(
    "-ckpt",
    "--ckpt_path",
    type=str,
    required=False,
    help="The path to the pretrained .ckpt model",
    default=None,
)

parser.add_argument(
    "-b",
    "--batchsize",
    type=int,
    required=False,
    default=1,
    help="Generate how many samples at the same time",
)

parser.add_argument(
    "--ddim_steps",
    type=int,
    required=False,
    default=200,
    help="The sampling step for DDIM",
)

parser.add_argument(
    "-gs",
    "--guidance_scale",
    type=float,
    required=False,
    default=2.5,
    help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
)

parser.add_argument(
    "-dur",
    "--duration",
    type=float,
    required=False,
    default=10.0,
    help="The duration of the samples",
)

parser.add_argument(
    "-n",
    "--n_candidate_gen_per_text",
    type=int,
    required=False,
    default=3,
    help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
)

parser.add_argument(
    "--seed",
    type=int,
    required=False,
    default=42,
    help="Change this value (any integer number) will lead to a different generation result.",
)

args = parser.parse_args()

if args.ckpt_path is not None:
    print("Warning: ckpt_path has no effect after version 0.0.20.")

# Validate user input up front with parser.error() rather than `assert`:
# asserts are stripped under `python -O`, and failing here avoids creating
# output directories and loading the model for a run that cannot succeed.
if args.duration <= 0 or args.duration % 2.5 != 0:
    parser.error("Duration must be a positive multiple of 2.5")

if args.mode == "transfer":
    if args.file_path is None:
        parser.error("--file_path is required when --mode is transfer")
    if not os.path.exists(args.file_path):
        parser.error(
            "The original audio file '%s' for style transfer does not exist."
            % args.file_path
        )

# `mode` only selects the output sub-directory; dispatch below still uses
# args.mode, so audio-guided generation goes through text_to_audio.
mode = args.mode
if mode == "generation" and args.file_path is not None:
    mode = "generation_audio_to_audio"
    if len(args.text) > 0:
        print("Warning: You have specified the --file_path. --text will be ignored")
        args.text = ""

save_path = os.path.join(args.save_path, mode)

if args.file_path is not None:
    # Take the file name first and strip only its extension.  Splitting the
    # whole path on "." yields an empty or truncated name for inputs such as
    # "./clips/dog.wav" or "my.dir/dog.wav".
    input_stem = os.path.splitext(os.path.basename(args.file_path))[0]
    save_path = os.path.join(save_path, input_stem)

text = args.text
random_seed = args.seed
duration = args.duration
guidance_scale = args.guidance_scale
n_candidate_gen_per_text = args.n_candidate_gen_per_text

os.makedirs(save_path, exist_ok=True)
audioldm = build_model(model_name=args.model_name)

if args.mode == "generation":
    waveform = text_to_audio(
        audioldm,
        text,
        args.file_path,
        random_seed,
        duration=duration,
        guidance_scale=guidance_scale,
        ddim_steps=args.ddim_steps,
        n_candidate_gen_per_text=n_candidate_gen_per_text,
        batchsize=args.batchsize,
    )
elif args.mode == "transfer":
    waveform = style_transfer(
        audioldm,
        text,
        args.file_path,
        args.transfer_strength,
        random_seed,
        duration=duration,
        guidance_scale=guidance_scale,
        ddim_steps=args.ddim_steps,
        batchsize=args.batchsize,
    )
    # NOTE(review): presumably style_transfer returns a 2-D batch and
    # save_wave expects an extra middle axis, hence the inserted dimension
    # for this branch only -- confirm against the audioldm package.
    waveform = waveform[:, None, :]

save_wave(waveform, save_path, name="%s_%s" % (get_time(), text))