| | |
| | |
| | |
| | |
| | |
| | |
| | import os, glob, sys, torchaudio |
| | import numpy as np |
| | import scipy.io.wavfile as Wavfile |
| | import matplotlib.pyplot as plt |
| | from sklearn.metrics import confusion_matrix |
| |
|
import subprocess

# Make Silero VAD importable from a checkout that sits next to this script.
# The clone goes into the script directory (not the current working
# directory) so that the sys.path entry below and the model path used by
# silero_vad_inference_single_file both point at it. The clone is skipped when
# a previous run already fetched the repository.
_SILERO_VAD_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "silero-vad"
)
if not os.path.isdir(_SILERO_VAD_DIR):
    subprocess.run(
        ["git", "clone", "https://github.com/snakers4/silero-vad.git", _SILERO_VAD_DIR],
        check=True,
    )
sys.path.append(os.path.join(_SILERO_VAD_DIR, "src"))
from silero_vad.utils_vad import VADIterator, init_jit_model

# TEN VAD Python bindings are imported from ../include relative to this script.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../include")))
from ten_vad import TenVad
| |
|
def convert_label_to_framewise(label_file, hop_size, sample_rate=16000):
    """Expand a segment-level label file into one 0/1 label per frame.

    Only the first line of *label_file* is used. It is split on commas, the
    first field is skipped, and the remaining fields are consumed as
    ``start, end, label`` triples (start/end in seconds, label 0 or 1).

    Args:
        label_file: Path to the label file.
        hop_size: Frame hop in samples.
        sample_rate: Sample rate in Hz used to convert *hop_size* into seconds.
            Defaults to 16000, the value this function previously hard-coded.

    Returns:
        1-D float64 array holding one 0.0/1.0 entry per frame. Each segment
        contributes ``round((end - start) / frame_duration)`` frames. The
        result is truncated to ``int((end[-1] - start[0]) / frame_duration)``
        frames. An empty label line yields an empty array.

    Raises:
        ValueError: if the fields do not form complete triples, or a segment
            label is not 0 or 1.
    """
    frame_duration = hop_size / sample_rate
    with open(label_file, "r", encoding="utf-8") as f:
        first_line = f.readline()
    content = first_line.strip().split(",")[1:]

    start = np.array(content[0::3], dtype=float)
    end = np.array(content[1::3], dtype=float)
    lab_manual = np.array(content[2::3], dtype=int)
    if not (len(start) == len(end) == len(lab_manual)):
        raise ValueError(
            f"{label_file}: label fields do not form complete "
            "(start, end, label) triples"
        )
    if len(start) == 0:
        return np.array([])
    # Any other value used to silently reuse the previous segment's frames
    # (or raise NameError on the first segment); reject it explicitly.
    if not np.isin(lab_manual, (0, 1)).all():
        raise ValueError(f"{label_file}: segment labels must be 0 or 1")

    num = np.round((end - start) / frame_duration).astype(np.int32)
    # Repeat each segment's label once per frame it spans, in one vectorized
    # step instead of growing an array segment by segment.
    label_framewise = np.repeat(lab_manual.astype(float), num)

    frame_num = min(
        len(label_framewise), int((end[-1] - start[0]) / frame_duration)
    )
    return label_framewise[:frame_num]
| |
|
| |
|
def read_file(file_path):
    """Read one float per line from *file_path* into a 1-D float64 array.

    Blank or whitespace-only lines (for example a trailing newline at the end
    of the file) are skipped instead of failing in ``float('')``.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        # Build a list once and convert at the end; np.append in a loop copies
        # the whole array on every line.
        values = [float(line.strip()) for line in f if line.strip()]
    return np.array(values, dtype=np.float64)
| |
|
def get_precision_recall(VAD_result, label, threshold):
    """Binarize per-frame scores at *threshold* and score them against *label*.

    Args:
        VAD_result: Per-frame scores. A frame is predicted positive when its
            score is ``>= threshold``.
        label: Per-frame ground truth. Frames equal to 1 are positive; all
            others are negative.
        threshold: Decision threshold.

    Returns:
        ``(precision, recall, FPR, FNR)``. Each ratio is 0 when its
        denominator is 0.

    Raises:
        ValueError: if *VAD_result* and *label* have different shapes.
    """
    scores = np.asarray(VAD_result)
    truth = np.asarray(label) == 1
    if scores.shape != truth.shape:
        raise ValueError(
            f"VAD_result shape {scores.shape} does not match label shape {truth.shape}"
        )
    predicted = scores >= threshold

    # Count the confusion-matrix cells directly. sklearn's confusion_matrix
    # (without labels=[0, 1]) returns a 1x1 matrix when only one class occurs
    # in both arrays, which broke the four-way unpacking.
    TP = np.count_nonzero(predicted & truth)
    FP = np.count_nonzero(predicted & ~truth)
    FN = np.count_nonzero(~predicted & truth)
    TN = np.count_nonzero(~predicted & ~truth)

    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    FPR = FP / (FP + TN) if (FP + TN) > 0 else 0
    FNR = FN / (TP + FN) if (TP + FN) > 0 else 0

    return precision, recall, FPR, FNR
| |
|
def silero_vad_inference_single_file(wav_path):
    """Run the Silero VAD JIT model over *wav_path* in fixed 512-sample windows.

    The model is loaded from the ``silero-vad`` checkout next to this script
    on every call. Any trailing samples that do not fill a whole window are
    ignored.

    Returns:
        ``(speech_probs, window_size_samples)``. *speech_probs* is a 1-D
        float64 array with one model output (``.item()``) per full window.
        *window_size_samples* is the window length in samples (512).
    """
    current_directory = os.path.dirname(os.path.abspath(__file__))
    model = init_jit_model(f'{current_directory}/silero-vad/src/silero_vad/data/silero_vad.jit')
    vad_iterator = VADIterator(model)
    # NOTE(review): 512 samples presumably matches 16 kHz input — confirm the
    # test-set sample rate.
    window_size_samples = 512

    wav, sr = torchaudio.load(wav_path)
    wav = wav.squeeze(0)
    num_full_windows = len(wav) // window_size_samples
    # Collect into a list and convert once; np.append copied the whole array
    # per window, making the loop quadratic in the audio length.
    speech_probs = []
    for window_idx in range(num_full_windows):
        begin = window_idx * window_size_samples
        chunk = wav[begin:begin + window_size_samples]
        speech_probs.append(model(chunk, sr).item())
    vad_iterator.reset_states()

    return np.array(speech_probs, dtype=np.float64), window_size_samples
| |
|
def ten_vad_process_wav(ten_vad_instance, wav_path, hop_size=256):
    """Feed *wav_path* to *ten_vad_instance* one hop at a time.

    The file's sample rate is read but not checked. Trailing samples that do
    not fill a whole hop are ignored.

    Args:
        ten_vad_instance: Object whose ``process(frame)`` returns a pair whose
            first element is the frame's voice probability.
        wav_path: Path to the WAV file.
        hop_size: Samples per frame passed to ``process``.

    Returns:
        1-D float64 array with one voice probability per full hop.
    """
    _, data = Wavfile.read(wav_path)
    num_frames = data.shape[0] // hop_size
    # Collect into a list and convert once; np.append copied the whole array
    # per frame, making the loop quadratic in the audio length.
    voice_probs = []
    for frame_idx in range(num_frames):
        frame = data[frame_idx * hop_size:(frame_idx + 1) * hop_size]
        voice_prob, _ = ten_vad_instance.process(frame)
        voice_probs.append(voice_prob)
    return np.array(voice_probs, dtype=np.float64)
| |
|
if __name__ == "__main__":
    # Evaluate TEN VAD and Silero VAD on the test set:
    #   1. gather per-frame scores and labels for every file;
    #   2. sweep decision thresholds to build precision-recall curves;
    #   3. plot both curves and save TEN VAD's PR table.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Each <name>.wav in the test set has a matching <name>.scv label file.
    test_dir = f"{script_dir}/../testset"

    hop_size = 256  # TEN VAD hop, in samples
    threshold = 0.5  # threshold passed to the TenVad constructor
    # Per-file arrays are collected in lists and concatenated once at the end;
    # growing an ndarray with np.append copies it every time.
    label_parts, ten_vad_parts = [], []
    label_silero_parts, silero_vad_parts = [], []
    # Sorted so the file order (and thus the concatenated arrays) is reproducible.
    wav_list = sorted(glob.glob(f"{test_dir}/*.wav"))
    if not wav_list:
        sys.exit(f"No .wav files found in {test_dir}")

    print("Start processing")
    for wav_path in wav_list:
        # Swap only the extension; str.replace would also rewrite a ".wav"
        # that appears elsewhere in the path.
        label_file = os.path.splitext(wav_path)[0] + ".scv"

        # --- TEN VAD ---
        ten_vad_instance = TenVad(hop_size, threshold)
        label = convert_label_to_framewise(label_file, hop_size=hop_size)
        vad_result_ten_vad = ten_vad_process_wav(
            ten_vad_instance, wav_path, hop_size=hop_size
        )
        del ten_vad_instance
        frame_num = min(len(label), len(vad_result_ten_vad))
        # Output frame k + 1 is scored against label frame k (presumably to
        # absorb one hop of model latency — TODO confirm).
        # Clamp at 0: with frame_num == 0 the old label[:frame_num - 1] became
        # label[:-1] and appended labels that had no matching scores.
        aligned_num = max(frame_num - 1, 0)
        ten_vad_parts.append(vad_result_ten_vad[1:frame_num])
        label_parts.append(label[:aligned_num])

        # --- Silero VAD ---
        vad_result_silero_vad, silero_hop = silero_vad_inference_single_file(wav_path)
        # Build labels at the window size the Silero helper actually used.
        label_silero = convert_label_to_framewise(label_file, hop_size=silero_hop)
        frame_num_silero_vad = min(len(label_silero), len(vad_result_silero_vad))
        silero_vad_parts.append(vad_result_silero_vad[:frame_num_silero_vad])
        label_silero_parts.append(label_silero[:frame_num_silero_vad])

    label_all = np.concatenate(label_parts)
    vad_result_ten_vad_all = np.concatenate(ten_vad_parts)
    label_silero_all = np.concatenate(label_silero_parts)
    vad_result_silero_vad_all = np.concatenate(silero_vad_parts)

    # Sweep thresholds 0.00 .. 1.00; each row is (precision, recall, threshold).
    threshold_arr = np.arange(0, 1.01, 0.01)
    pr_data_arr = np.zeros((len(threshold_arr), 3))
    pr_data_silero_vad_arr = np.zeros((len(threshold_arr), 3))
    for ind, thr in enumerate(threshold_arr):
        precision, recall, _, _ = get_precision_recall(
            vad_result_ten_vad_all, label_all, thr
        )
        pr_data_arr[ind] = precision, recall, thr

        precision_silero, recall_silero, _, _ = get_precision_recall(
            vad_result_silero_vad_all, label_silero_all, thr
        )
        pr_data_silero_vad_arr[ind] = precision_silero, recall_silero, thr

    print("Plotting PR Curve")
    # The last row (the highest threshold) is left out of both curves.
    for pr_data, color, name in (
        (pr_data_arr, "red", "TEN VAD"),
        (pr_data_silero_vad_arr, "blue", "Silero VAD"),
    ):
        to_plot = pr_data[:-1]
        plt.plot(to_plot[:, 1], to_plot[:, 0], color=color, label=name)

    plt.xlabel("Recall", fontsize=14, fontweight="bold", color="black")
    plt.ylabel("Precision", fontsize=14, fontweight="bold", color="black")
    legend = plt.legend()
    for text in legend.get_texts():
        text.set_fontweight("bold")
    plt.grid(True)
    plt.xlim(0.65, 1)
    plt.ylim(0.7, 1)
    plt.title(
        "Precision-Recall Curve of TEN VAD on TEN-VAD-TestSet",
        fontsize=12,
        color="black",
        fontweight="bold",
    )
    save_path = f"{script_dir}/PR_Curves.png"
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    print(f"PR Curves png file saved, save path: {save_path}")

    # One line per threshold: "<threshold> <precision> <recall>".
    pr_data_save_path = f"{script_dir}/PR_data_TEN_VAD.txt"
    with open(pr_data_save_path, "w") as f:
        for precision, recall, thr in pr_data_arr:
            f.write(f"{thr:.2f} {precision:.4f} {recall:.4f}\n")
    print("Processing done!")
| |
|
| |
|
| |
|