Spaces:
Sleeping
Sleeping
import argparse | |
import pathlib | |
import tqdm | |
from torch.utils.data import Dataset, DataLoader | |
from score import loadWav, Score | |
import torch | |
import os | |
import warnings | |
warnings.filterwarnings("ignore") | |
def get_arg(argv=None):
    """Parse the command-line options of the VoxSIM prediction script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace with the attributes ``mode``, ``ckpt_path``,
        ``inp_dir``, ``ref_dir``, ``inp_path``, ``ref_path``, ``out_path``
        and ``num_workers``. Path-like options are ``pathlib.Path`` objects,
        or ``None`` when an optional path was not supplied.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            a value is invalid.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        required=True,
        choices=["predict_file", "predict_dir"],
        type=str,
        help="predict mode",
    )
    # argparse runs string defaults through ``type``, so this default also
    # ends up as a pathlib.Path.
    parser.add_argument(
        "--ckpt_path",
        required=False,
        default="voxsim_wavlm_ecapa.model",
        type=pathlib.Path,
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "--inp_dir",
        required=False,
        default=None,
        type=pathlib.Path,
        help="input directory when predict_dir mode",
    )
    parser.add_argument(
        "--ref_dir",
        required=False,
        default=None,
        type=pathlib.Path,
        help="reference directory when predict_dir mode",
    )
    parser.add_argument(
        "--inp_path",
        required=False,
        default=None,
        type=pathlib.Path,
        help="input file when predict_file mode",
    )
    parser.add_argument(
        "--ref_path",
        required=False,
        default=None,
        type=pathlib.Path,
        help="reference file when predict_file mode",
    )
    parser.add_argument(
        "--out_path",
        required=True,
        type=pathlib.Path,
        help="output path",
    )
    parser.add_argument(
        "--num_workers",
        required=False,
        default=4,
        type=int,
        help="number of workers for dataloader",
    )
    return parser.parse_args(argv)
class AudioDataset(Dataset):
    """Dataset of input/reference ``.wav`` pairs matched by file name.

    Every ``.wav`` file in the input directory must have a file of the same
    name in the reference directory. Items are served in sorted file-name
    order.
    """

    def __init__(self, inp_dir_path: pathlib.Path, ref_dir_path: pathlib.Path, max_frames: int = 400):
        """Collect the ``.wav`` names and check that each has a reference.

        Args:
            inp_dir_path: Directory holding the input ``.wav`` files.
            ref_dir_path: Directory holding the reference ``.wav`` files.
            max_frames: Used only to derive ``self.max_audio``.

        Raises:
            ValueError: If some input ``.wav`` file has no same-named file in
                the reference directory.
        """
        self.inp_dir_path = inp_dir_path
        self.ref_dir_path = ref_dir_path

        input_names = {name for name in os.listdir(inp_dir_path) if name.endswith(".wav")}
        reference_names = {name for name in os.listdir(ref_dir_path) if name.endswith(".wav")}

        unmatched = sorted(input_names - reference_names)
        if unmatched:
            raise ValueError(f"Files {unmatched} are in inp_dir but not in ref_dir.")

        self.inp_wavlist = sorted(input_names)
        # Stored on the instance but not read by any method of this class.
        self.max_audio = max_frames * 160 + 240

    def __len__(self):
        """Return the number of input ``.wav`` files."""
        return len(self.inp_wavlist)

    def __getitem__(self, idx):
        """Load the input/reference pair at position ``idx``.

        Returns:
            The two values ``loadWav`` yields for the input file followed by
            the two values it yields for the same-named reference file.
        """
        wav_name = self.inp_wavlist[idx]
        inp_wavs, inp_wav = loadWav(os.path.join(self.inp_dir_path, wav_name))
        ref_wavs, ref_wav = loadWav(os.path.join(self.ref_dir_path, wav_name))
        return inp_wavs, inp_wav, ref_wavs, ref_wav
def _check_input(path, arg_name, mode, *, expect_dir):
    """Raise unless ``path`` was supplied and is an existing file/directory.

    Explicit exceptions are used instead of ``assert`` because assertions
    are removed when Python runs with ``-O``.
    """
    if path is None:
        raise ValueError(f"{arg_name} is required when mode is {mode}.")
    if not path.exists():
        raise FileNotFoundError(f"{arg_name} does not exist: {path}")
    if expect_dir and not path.is_dir():
        raise NotADirectoryError(f"{arg_name} is not a directory: {path}")
    if not expect_dir and not path.is_file():
        raise ValueError(f"{arg_name} is not a regular file: {path}")


def _predict_file(args, device):
    """Score one input/reference file pair and write the score to out_path."""
    _check_input(args.inp_path, "inp_path", "predict_file", expect_dir=False)
    _check_input(args.ref_path, "ref_path", "predict_file", expect_dir=False)

    inp_wavs, inp_wav = loadWav(args.inp_path)
    ref_wavs, ref_wav = loadWav(args.ref_path)
    scorer = Score(ckpt_path=args.ckpt_path, device=device)
    score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
    print("VoxSIM score: ", score)
    with open(args.out_path, "w") as fw:
        fw.write(str(score))


def _predict_dir(args, device):
    """Score every same-named .wav pair of two directories, one score per line."""
    _check_input(args.inp_dir, "inp_dir", "predict_dir", expect_dir=True)
    _check_input(args.ref_dir, "ref_dir", "predict_dir", expect_dir=True)

    dataset = AudioDataset(args.inp_dir, args.ref_dir)
    if len(dataset) == 0:
        # Fail before loading the checkpoint; otherwise the average below
        # would divide by zero after creating an empty output file.
        raise ValueError(f"No .wav files found in inp_dir: {args.inp_dir}")

    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
    )
    scorer = Score(ckpt_path=args.ckpt_path, device=device)

    scores = []
    with open(args.out_path, "w") as fw:
        for inp_wavs, inp_wav, ref_wavs, ref_wav in tqdm.tqdm(loader):
            score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
            scores.append(score)
            fw.write(str(score) + "\n")
    print("Average VoxSIM score: ", sum(scores) / len(scores))


def main():
    """Command-line entry point: score a file pair or two directories.

    Reads the options via ``get_arg()``, selects CUDA when available (CPU
    otherwise), runs the selected prediction mode and writes the result to
    ``args.out_path``.

    Raises:
        ValueError: If a path required by the selected mode is missing, an
            expected file is not a regular file, or the input directory
            contains no ``.wav`` files.
        FileNotFoundError: If a given input path does not exist.
        NotADirectoryError: If ``inp_dir``/``ref_dir`` is not a directory.
    """
    args = get_arg()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.mode == "predict_file":
        _predict_file(args, device)
    else:
        _predict_dir(args, device)

    # NOTE(review): both modes write to out_path, so the confirmation is
    # printed for both - confirm this matches the intended placement.
    print("save to ", args.out_path)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()