import argparse import os import shutil from pathlib import Path import soundfile as sf import torch from tqdm import tqdm from common.log import logger from common.stdout_wrapper import SAFE_STDOUT vad_model, utils = torch.hub.load( repo_or_dir="snakers4/silero-vad", model="silero_vad", onnx=True, trust_repo=True, ) (get_speech_timestamps, _, read_audio, *_) = utils def get_stamps( audio_file, min_silence_dur_ms: int = 700, min_sec: float = 2, max_sec: float = 12 ): """ min_silence_dur_ms: int (ミリ秒): このミリ秒数以上を無音だと判断する。 逆に、この秒数以下の無音区間では区切られない。 小さくすると、音声がぶつ切りに小さくなりすぎ、 大きくすると音声一つ一つが長くなりすぎる。 データセットによってたぶん要調整。 min_sec: float (秒): この秒数より小さい発話は無視する。 max_sec: float (秒): この秒数より大きい発話は無視する。 """ sampling_rate = 16000 # 16kHzか8kHzのみ対応 min_ms = int(min_sec * 1000) wav = read_audio(audio_file, sampling_rate=sampling_rate) speech_timestamps = get_speech_timestamps( wav, vad_model, sampling_rate=sampling_rate, min_silence_duration_ms=min_silence_dur_ms, min_speech_duration_ms=min_ms, max_speech_duration_s=max_sec, ) return speech_timestamps def split_wav( audio_file, target_dir="raw", min_sec=2, max_sec=12, min_silence_dur_ms=700, ): margin = 200 # ミリ秒単位で、音声の前後に余裕を持たせる speech_timestamps = get_stamps( audio_file, min_silence_dur_ms=min_silence_dur_ms, min_sec=min_sec, max_sec=max_sec, ) data, sr = sf.read(audio_file) total_ms = len(data) / sr * 1000 file_name = os.path.basename(audio_file).split(".")[0] os.makedirs(target_dir, exist_ok=True) total_time_ms = 0 # タイムスタンプに従って分割し、ファイルに保存 for i, ts in enumerate(speech_timestamps): start_ms = max(ts["start"] / 16 - margin, 0) end_ms = min(ts["end"] / 16 + margin, total_ms) start_sample = int(start_ms / 1000 * sr) end_sample = int(end_ms / 1000 * sr) segment = data[start_sample:end_sample] sf.write(os.path.join(target_dir, f"{file_name}-{i}.wav"), segment, sr) total_time_ms += end_ms - start_ms return total_time_ms / 1000 if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--min_sec", "-m", type=float, default=2, help="Minimum seconds of a slice" ) parser.add_argument( "--max_sec", "-M", type=float, default=12, help="Maximum seconds of a slice" ) parser.add_argument( "--input_dir", "-i", type=str, default="inputs", help="Directory of input wav files", ) parser.add_argument( "--output_dir", "-o", type=str, default="raw", help="Directory of output wav files", ) parser.add_argument( "--min_silence_dur_ms", "-s", type=int, default=700, help="Silence above this duration (ms) is considered as a split point.", ) args = parser.parse_args() input_dir = args.input_dir output_dir = args.output_dir min_sec = args.min_sec max_sec = args.max_sec min_silence_dur_ms = args.min_silence_dur_ms wav_files = Path(input_dir).glob("**/*.wav") wav_files = list(wav_files) logger.info(f"Found {len(wav_files)} wav files.") if os.path.exists(output_dir): logger.warning(f"Output directory {output_dir} already exists, deleting...") shutil.rmtree(output_dir) total_sec = 0 for wav_file in tqdm(wav_files, file=SAFE_STDOUT): time_sec = split_wav( audio_file=str(wav_file), target_dir=output_dir, min_sec=min_sec, max_sec=max_sec, min_silence_dur_ms=min_silence_dur_ms, ) total_sec += time_sec logger.info(f"Slice done! Total time: {total_sec / 60:.2f} min.")