"""Compute VoxSIM speaker-similarity scores for one file pair or two directories."""
import argparse
import os
import pathlib
import warnings

import torch
import tqdm
from torch.utils.data import Dataset, DataLoader

from score import loadWav, Score

warnings.filterwarnings("ignore")


def get_arg():
    """Parse and return the command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", required=True, choices=["predict_file", "predict_dir"], type=str,
                        help="predict mode")
    parser.add_argument("--ckpt_path", required=False, default="voxsim_wavlm_ecapa.model", type=pathlib.Path,
                        help="path to the model checkpoint")
    parser.add_argument("--inp_dir", required=False, default=None, type=pathlib.Path,
                        help="input directory when predict_dir mode")
    parser.add_argument("--ref_dir", required=False, default=None, type=pathlib.Path,
                        help="reference directory when predict_dir mode")
    parser.add_argument("--inp_path", required=False, default=None, type=pathlib.Path,
                        help="input file when predict_file mode")
    parser.add_argument("--ref_path", required=False, default=None, type=pathlib.Path,
                        help="reference file when predict_file mode")
    parser.add_argument("--out_path", required=True, type=pathlib.Path,
                        help="output path")
    parser.add_argument("--num_workers", required=False, default=4, type=int,
                        help="number of workers for dataloader")
    return parser.parse_args()


class AudioDataset(Dataset):
    """Pairs each ``.wav`` in ``inp_dir_path`` with the same-named ``.wav`` in ``ref_dir_path``.

    Items are ordered by input filename. Reference files that have no input
    counterpart are ignored.

    Raises:
        ValueError: if some input ``.wav`` has no same-named file in the
            reference directory.
    """

    def __init__(self, inp_dir_path: pathlib.Path, ref_dir_path: pathlib.Path, max_frames: int = 400):
        self.inp_dir_path = inp_dir_path
        self.ref_dir_path = ref_dir_path
        self.inp_wavlist = [file for file in os.listdir(inp_dir_path) if file.endswith(".wav")]
        inp_wavset = set(self.inp_wavlist)
        ref_wavset = {file for file in os.listdir(ref_dir_path) if file.endswith(".wav")}
        diff = inp_wavset - ref_wavset
        if diff:
            raise ValueError(f"Files {sorted(diff)} are in inp_dir but not in ref_dir.")
        # Sort so that output lines follow a deterministic filename order.
        self.inp_wavlist.sort()
        # NOTE(review): not referenced anywhere in this file; kept in case
        # external code reads it — confirm before removing.
        self.max_audio = max_frames * 160 + 240

    def __len__(self):
        return len(self.inp_wavlist)

    def __getitem__(self, idx):
        """Return ``loadWav``'s two outputs for the input file, then for its reference."""
        name = self.inp_wavlist[idx]
        inp_wavs, inp_wav = loadWav(os.path.join(self.inp_dir_path, name))
        ref_wavs, ref_wav = loadWav(os.path.join(self.ref_dir_path, name))
        return inp_wavs, inp_wav, ref_wavs, ref_wav


def _require_path(path, name, mode, *, want_dir):
    """Validate one path argument for ``mode`` and return it unchanged.

    Explicit raises are used instead of ``assert`` so the checks still run
    under ``python -O``.

    Raises:
        ValueError: if ``path`` is None, or is not a regular file when
            ``want_dir`` is False.
        FileNotFoundError: if ``path`` does not exist.
        NotADirectoryError: if ``want_dir`` is True and ``path`` is not a directory.
    """
    if path is None:
        raise ValueError(f"{name} is required when mode is {mode}.")
    if not path.exists():
        raise FileNotFoundError(f"{name} does not exist: {path}")
    if want_dir and not path.is_dir():
        raise NotADirectoryError(f"{name} is not a directory: {path}")
    if not want_dir and not path.is_file():
        raise ValueError(f"{name} is not a regular file: {path}")
    return path


def _predict_file(args, device):
    """Score one input/reference pair; print the score and write it to ``args.out_path``."""
    inp_path = _require_path(args.inp_path, "inp_path", "predict_file", want_dir=False)
    ref_path = _require_path(args.ref_path, "ref_path", "predict_file", want_dir=False)
    inp_wavs, inp_wav = loadWav(inp_path)
    ref_wavs, ref_wav = loadWav(ref_path)
    scorer = Score(ckpt_path=args.ckpt_path, device=device)
    score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
    print("VoxSIM score: ", score)
    with open(args.out_path, "w", encoding="utf-8") as fw:
        fw.write(str(score))


def _predict_dir(args, device):
    """Score every same-named pair in two directories.

    Writes one score per line to ``args.out_path``, in sorted input-filename
    order, and prints the average score. An input directory with no ``.wav``
    files produces an empty output file and a message, instead of a
    ``ZeroDivisionError``.
    """
    inp_dir = _require_path(args.inp_dir, "inp_dir", "predict_dir", want_dir=True)
    ref_dir = _require_path(args.ref_dir, "ref_dir", "predict_dir", want_dir=True)
    dataset = AudioDataset(inp_dir, ref_dir)
    # batch_size=1: the default collate adds a leading batch dimension, so
    # scorer.score sees batched inputs here, unlike _predict_file.
    # Presumably Score handles both layouts — confirm in score.py.
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.num_workers)
    scorer = Score(ckpt_path=args.ckpt_path, device=device)
    scores = []
    with open(args.out_path, "w", encoding="utf-8") as fw:
        for inp_wavs, inp_wav, ref_wavs, ref_wav in tqdm.tqdm(loader):
            score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
            scores.append(score)
            # str() rather than an f-string: f"{tensor}" formats a 0-dim
            # tensor as a bare number, which would change the output format.
            fw.write(str(score) + "\n")
    if scores:
        print("Average VoxSIM score: ", sum(scores) / len(scores))
    else:
        print("No .wav files found in ", inp_dir, "; no average computed.")
    print("save to ", args.out_path)


def main():
    """CLI entry point: dispatch to file or directory scoring based on ``--mode``."""
    args = get_arg()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.mode == "predict_file":
        _predict_file(args, device)
    else:
        # argparse `choices` guarantees the only other mode is predict_dir.
        _predict_dir(args, device)


if __name__ == "__main__":
    main()