#!/usr/bin/python3
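"""Command line interface for AudioSR audio super resolution.

Upsamples one audio file, or every file named in a list file, to 48 kHz
with a latent diffusion model and writes the results to a timestamped
output directory.
"""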
import os
import torch
import logging
from audiosr import super_resolution, build_model, save_wave, get_time, read_list
import argparse
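
# Keep HuggingFace tokenizers parallelism explicitly enabled (avoids the fork warning).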
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Silence noisy matplotlib debug output.
matplotlib_logger = logging.getLogger("matplotlib")
matplotlib_logger.setLevel(logging.WARNING)

parser = argparse.ArgumentParser()
parser.add_argument(
    "-i",
    "--input_audio_file",
    type=str,
    required=False,
    help="Input audio file for audio super resolution",
)
parser.add_argument(
    "-il",
    "--input_file_list",
    type=str,
    required=False,
    default="",
    help="A text file listing all audio files to process with audio super resolution",
)
parser.add_argument(
    "-s",
    "--save_path",
    type=str,
    required=False,
    help="The path to save model output",
    default="./output",
)
parser.add_argument(
    "--model_name",
    type=str,
    required=False,
    help="The model checkpoint to use",
    default="basic",
    choices=["basic", "speech"],
)
parser.add_argument(
    "-d",
    "--device",
    type=str,
    required=False,
    help="The device for computation. If not specified, the script will automatically choose the device based on your environment.",
    default="auto",
)
parser.add_argument(
    "--ddim_steps",
    type=int,
    required=False,
    default=50,
    help="The number of DDIM sampling steps",
)
parser.add_argument(
    "-gs",
    "--guidance_scale",
    type=float,
    required=False,
    default=3.5,
    help="Guidance scale (large => better quality and relevance to text; small => better diversity)",
)
parser.add_argument(
    "--seed",
    type=int,
    required=False,
    default=42,
    help="Changing this value (any integer) will lead to a different generation result.",
)
parser.add_argument(
    "--suffix",
    type=str,
    required=False,
    help="Suffix for the output file",
    default="_AudioSR_Processed_48K",
)
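
# Example invocations (hypothetical script name; flags as defined above):
#   python3 audiosr_cli.py -i vocals.wav --model_name speech
#   python3 audiosr_cli.py -il file_list.txt --ddim_steps 100 -gs 3.5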
args = parser.parse_args()

# The original check compared against None, but --input_file_list defaults
# to "", so it could never fire; test truthiness instead.
assert args.input_audio_file or args.input_file_list, (
    "Please provide either a single audio file or a list of audio files"
)

torch.set_float32_matmul_precision("high")

# Save outputs under a timestamped subdirectory of the requested path.
save_path = os.path.join(args.save_path, get_time())
os.makedirs(save_path, exist_ok=True)

input_file = args.input_audio_file
random_seed = args.seed
sample_rate = 48000  # AudioSR always produces 48 kHz output
latent_t_per_second = 12.8
guidance_scale = args.guidance_scale
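
# Build the diffusion model once and reuse it for every input file.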
audiosr = build_model(model_name=args.model_name, device=args.device)

if args.input_file_list:
    print("Performing audio super resolution on the files listed in %s" % args.input_file_list)
    files_todo = read_list(args.input_file_list)
else:
    files_todo = [input_file]

for input_file in files_todo:
    # Name each output after the input file plus the configured suffix.
    name = os.path.splitext(os.path.basename(input_file))[0] + args.suffix
    waveform = super_resolution(
        audiosr,
        input_file,
        seed=random_seed,
        guidance_scale=guidance_scale,
        ddim_steps=args.ddim_steps,
        latent_t_per_second=latent_t_per_second,
    )
    save_wave(waveform, inputpath=input_file, savepath=save_path, name=name, samplerate=sample_rate)