Spaces:
Running
on
T4
Running
on
T4
import os

import librosa
import numpy as np
import soundfile as sf
import torch
from tqdm import tqdm

from Preprocessing.TextFrontend import get_feature_to_index_lookup
from Utility.path_to_transcript_dicts import *
# Phoneme features marking positions where a pause is allowed to stay in the audio.
_PAUSE_FEATURES = ("silence", "end of sentence", "questionmark", "exclamationmark", "fullstop")


def _legal_pause_spans(phonemes, durations, pause_feature_indexes):
    """Return ``[start, end]`` spans (in duration units) of the phonemes that carry a pause feature."""
    spans = []
    position = 0
    for phoneme_index, phone in enumerate(phonemes):
        duration = durations[phoneme_index]
        if any(phone[feature_index] == 1 for feature_index in pause_feature_indexes):
            spans.append([position, position + duration])
        # the position advances for every phoneme, not only for the pauses
        position = position + duration
    return spans


def _gaps_between_speech(speech_timestamps):
    """Return the ``[start, end]`` spans between consecutive speech segments (no leading/trailing gap)."""
    return [[previous["end"], following["start"]]
            for previous, following in zip(speech_timestamps, speech_timestamps[1:])]


def _find_illegal_silences(silences, num_samples, legal_spans, speech_length):
    """Return the silences, as ``[start, end]`` ratios of the audio length, that overlap no legal pause."""
    if not silences:
        return []
    # Both kinds of spans are turned into ratios of their own total length so that they become comparable.
    legal_ratios = [(span[0] / speech_length, span[1] / speech_length) for span in legal_spans]
    illegal = []
    for silence_start, silence_end in silences:
        start = silence_start / num_samples
        end = silence_end / num_samples
        # True interval overlap: this also covers a detected silence that completely encloses a legal
        # pause, which an endpoint-inside-the-pause test misses. Zero-length legal spans never match.
        overlaps_legal_pause = any(legal_start < legal_end and start < legal_end and legal_start < end
                                   for legal_start, legal_end in legal_ratios)
        if not overlaps_legal_pause:
            illegal.append([start, end])
    return illegal


def _remove_ratio_spans(wave, ratio_spans):
    """Return a copy of ``wave`` without the samples covered by the ascending, non-overlapping ``ratio_spans``."""
    total = len(wave)
    kept_pieces = []
    cursor = 0
    for start_ratio, end_ratio in ratio_spans:
        kept_pieces.append(wave[cursor:int(start_ratio * total)])
        cursor = int(end_ratio * total)
    kept_pieces.append(wave[cursor:])
    return np.concatenate(kept_pieces)


def make_silence_cleaned_versions(train_sets):
    """
    Write a copy of every audio file of the given datasets with unexpected pauses cut out.

    For each datapoint, the silero voice activity detector is run on a 16kHz version of the audio.
    Every gap between two detected speech segments counts as a silence. A silence may stay if it
    overlaps with a phoneme that carries one of the features in ``_PAUSE_FEATURES``; all other
    silences are removed from the audio at its original sampling rate. Silence before the first
    and after the last speech segment is never touched.

    The result is written under the original file name into a sibling directory whose name is the
    original directory name plus ``_silence_removed``. The original files are not modified.

    Args:
        train_sets: iterable of datasets. Each one has to support ``len()`` and expose
            ``datapoints``, where a datapoint holds the phoneme feature vectors at index 0, the
            speech length at index 3, the phoneme durations at index 4 and the audio file path at
            index 8. Speech length and durations have to be in the same unit (presumably
            spectrogram frames - confirm against the dataset class).

    Returns:
        None
    """
    torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
    # careful: assumes 16kHz or 8kHz audio
    silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                         model='silero_vad',
                                         force_reload=False,
                                         onnx=False,
                                         verbose=False)
    get_speech_timestamps = utils[0]  # the remaining silero utilities are not needed here
    torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets
    # this to false globally during model loading rather than using inference mode or no_grad
    device = "cuda" if torch.cuda.is_available() else "cpu"
    silero_model = silero_model.to(device)

    # Resolved once instead of for every phoneme; the lookup is assumed to be constant.
    feature_to_index = get_feature_to_index_lookup()
    pause_feature_indexes = [feature_to_index[feature] for feature in _PAUSE_FEATURES]

    for train_set in train_sets:
        for index in tqdm(range(len(train_set))):
            datapoint = train_set.datapoints[index]
            filepath = datapoint[8]
            phonemes = datapoint[0]
            speech_length = datapoint[3]
            durations = datapoint[4]

            legal_silences = _legal_pause_spans(phonemes, durations, pause_feature_indexes)

            # NOTE(review): presumably all files are mono; multi-channel audio is not handled here - confirm.
            wave, sr = sf.read(filepath)
            resampled_wave = librosa.resample(wave, orig_sr=sr, target_sr=16000)
            with torch.inference_mode():
                speech_timestamps = get_speech_timestamps(torch.Tensor(resampled_wave).to(device), silero_model, sampling_rate=16000)
            silences = _gaps_between_speech(speech_timestamps)

            # Illegal silences are expressed as ratios, so they can be cut from the wave at its real sampling rate.
            illegal_silences = _find_illegal_silences(silences, len(resampled_wave), legal_silences, speech_length)
            cleaned_wave = _remove_ratio_spans(wave, illegal_silences)

            # Audio with illegal silences removed will be saved into a new directory.
            directory, filename = os.path.split(filepath)
            target_directory = directory + "_silence_removed"
            os.makedirs(target_directory, exist_ok=True)
            sf.write(os.path.join(target_directory, filename), cleaned_wave, sr)