Spaces:
Sleeping
Sleeping
import argparse | |
import os | |
import shutil | |
from pathlib import Path | |
import soundfile as sf | |
import torch | |
from tqdm import tqdm | |
from common.log import logger | |
from common.stdout_wrapper import SAFE_STDOUT | |
# Fetch the Silero voice-activity-detection model (ONNX build) together with
# its bundled helper functions via torch.hub. This executes once, at import
# time; torch.hub may need network access to fetch the repo on first use.
vad_model, utils = torch.hub.load(
    "snakers4/silero-vad",
    "silero_vad",
    onnx=True,
    trust_repo=True,
)
# Only two helpers are used by this script: the speech-timestamp extractor
# (position 0) and the audio reader (position 2). Everything else is
# discarded into `_`.
get_speech_timestamps, _, read_audio, *_ = utils
def get_stamps(
    audio_file, min_silence_dur_ms: int = 700, min_sec: float = 2, max_sec: float = 12
):
    """Detect speech segments in ``audio_file`` with the Silero VAD model.

    The audio is loaded with silero's ``read_audio`` helper at 16 kHz, so the
    ``"start"`` / ``"end"`` values of the returned timestamps are sample
    indices at 16 kHz (16 samples per millisecond).

    Args:
        audio_file: Path of the audio file to analyse.
        min_silence_dur_ms: A pause must last at least this many milliseconds
            for the speech to be split there; shorter pauses do not split.
            Small values chop the audio into many short pieces; large values
            make each piece long. Likely needs tuning per dataset.
        min_sec: Speech segments shorter than this many seconds are dropped.
        max_sec: Upper bound on segment length in seconds. Per the silero-vad
            documentation of ``max_speech_duration_s``, longer speech is split
            (at a pause when one is available), not discarded.

    Returns:
        The list produced by silero's ``get_speech_timestamps``: one dict per
        segment with ``"start"`` and ``"end"`` keys.
    """
    sampling_rate = 16000  # the Silero VAD model only supports 16 kHz or 8 kHz
    # round() rather than int(): float error can land just below the intended
    # value (1.005 * 1000 == 1004.9999999999999), which int() would truncate.
    min_ms = round(min_sec * 1000)
    wav = read_audio(audio_file, sampling_rate=sampling_rate)
    speech_timestamps = get_speech_timestamps(
        wav,
        vad_model,
        sampling_rate=sampling_rate,
        min_silence_duration_ms=min_silence_dur_ms,
        min_speech_duration_ms=min_ms,
        max_speech_duration_s=max_sec,
    )
    return speech_timestamps
def split_wav(
    audio_file,
    target_dir="raw",
    min_sec=2,
    max_sec=12,
    min_silence_dur_ms=700,
):
    """Cut ``audio_file`` into speech clips and write each one as a WAV file.

    Speech segments come from ``get_stamps``. Each segment is widened by a
    fixed margin on both sides, clamped to the bounds of the file, and written
    to ``target_dir`` as ``{stem}-{index}.wav`` at the file's own sample rate.
    Neighbouring clips can overlap when the gap between two segments is
    shorter than twice the margin.

    Args:
        audio_file: Path of the source audio file.
        target_dir: Output directory; created if it does not exist.
        min_sec: Forwarded to ``get_stamps``.
        max_sec: Forwarded to ``get_stamps``.
        min_silence_dur_ms: Forwarded to ``get_stamps``.

    Returns:
        Total duration, in seconds, of all written clips (margins included).
    """
    margin_ms = 200  # padding kept before and after each detected segment
    # get_stamps analyses the audio at 16 kHz, so its timestamps are sample
    # indices at that rate; dividing by this converts them to milliseconds.
    vad_samples_per_ms = 16000 // 1000

    speech_timestamps = get_stamps(
        audio_file,
        min_silence_dur_ms=min_silence_dur_ms,
        min_sec=min_sec,
        max_sec=max_sec,
    )

    data, sr = sf.read(audio_file)
    total_ms = len(data) / sr * 1000
    # Path.stem strips only the final extension, so "take.1.wav" and
    # "take.2.wav" stay distinct. Splitting the basename on the first "."
    # would map both to "take", and their clips would overwrite each other.
    file_name = Path(audio_file).stem
    os.makedirs(target_dir, exist_ok=True)

    total_time_ms = 0
    # Slice the original-rate audio at each (padded) timestamp and save it.
    for i, ts in enumerate(speech_timestamps):
        start_ms = max(ts["start"] / vad_samples_per_ms - margin_ms, 0)
        end_ms = min(ts["end"] / vad_samples_per_ms + margin_ms, total_ms)
        start_sample = int(start_ms / 1000 * sr)
        end_sample = int(end_ms / 1000 * sr)
        segment = data[start_sample:end_sample]
        sf.write(os.path.join(target_dir, f"{file_name}-{i}.wav"), segment, sr)
        total_time_ms += end_ms - start_ms
    return total_time_ms / 1000
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--min_sec", "-m", type=float, default=2, help="Minimum seconds of a slice"
    )
    parser.add_argument(
        "--max_sec", "-M", type=float, default=12, help="Maximum seconds of a slice"
    )
    parser.add_argument(
        "--input_dir",
        "-i",
        type=str,
        default="inputs",
        help="Directory of input wav files",
    )
    parser.add_argument(
        "--output_dir",
        "-o",
        type=str,
        default="raw",
        help="Directory of output wav files",
    )
    parser.add_argument(
        "--min_silence_dur_ms",
        "-s",
        type=int,
        default=700,
        help="Silence above this duration (ms) is considered as a split point.",
    )
    args = parser.parse_args()

    input_dir = args.input_dir
    output_dir = args.output_dir
    min_sec = args.min_sec
    max_sec = args.max_sec
    min_silence_dur_ms = args.min_silence_dur_ms

    # The output directory is wiped below. Refuse to run when that would also
    # delete the inputs, i.e. output_dir is input_dir or one of its ancestors.
    input_path = Path(input_dir).resolve()
    output_path = Path(output_dir).resolve()
    if output_path == input_path or output_path in input_path.parents:
        parser.error(
            f"--output_dir {output_dir!r} contains --input_dir {input_dir!r}; "
            "refusing to delete it."
        )

    # Clear old results *before* collecting inputs. When output_dir lives
    # inside input_dir, clips from a previous run would otherwise be listed
    # as inputs and then deleted before they are read.
    if os.path.exists(output_dir):
        logger.warning(f"Output directory {output_dir} already exists, deleting...")
        shutil.rmtree(output_dir)

    # Sorted so files are processed in a deterministic order across runs.
    wav_files = sorted(Path(input_dir).glob("**/*.wav"))
    logger.info(f"Found {len(wav_files)} wav files.")

    total_sec = 0
    for wav_file in tqdm(wav_files, file=SAFE_STDOUT):
        total_sec += split_wav(
            audio_file=str(wav_file),
            target_dir=output_dir,
            min_sec=min_sec,
            max_sec=max_sec,
            min_silence_dur_ms=min_silence_dur_ms,
        )

    logger.info(f"Slice done! Total time: {total_sec / 60:.2f} min.")