# Audio chunking/splitting/merging logic
import subprocess
from typing import List
from app.core.audio_utils import get_audio_info, make_temp_path
import soundfile as sf
import numpy as np
# optional webrtcvad for speech-based splitting
try:
import webrtcvad
_HAS_VAD = True
except Exception:
_HAS_VAD = False
def ffmpeg_extract_segment(src: str, start: float, duration: float, dst: str) -> str:
    """
    Extract segment [start, start+duration) from src into dst as a 16 kHz
    mono 16-bit PCM WAV. Placing -ss before -i makes ffmpeg seek on the
    input, which is fast for long files.
    """
    # Build the argument list directly (no shell-style quoting), so paths
    # containing spaces or quotes cannot break the command.
    cmd = [
        "ffmpeg", "-v", "error", "-y",
        "-ss", f"{start:.3f}", "-i", src, "-t", f"{duration:.3f}",
        "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst,
    ]
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if proc.returncode != 0:
        raise RuntimeError(f"ffmpeg extract failed: {proc.stderr.decode(errors='ignore')}")
    return dst
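# Usage sketch (the input path is illustrative, not part of the module API):
#   ffmpeg_extract_segment("input.wav", 60.0, 30.0, make_temp_path(suffix=".wav"))
# would extract the 30 s starting at t=60 s into a temporary 16 kHz mono WAV.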
def split_audio_to_chunks(src_wav: str, chunk_length_s: float = 30.0, overlap_s: float = 5.0) -> List[str]:
    """
    Split src_wav into fixed windows of chunk_length_s seconds, each window
    overlapping the previous one by overlap_s seconds. Returns chunk paths.
    """
info = get_audio_info(src_wav)
if not info:
raise RuntimeError("Cannot read audio info")
duration = info["duration"]
step = chunk_length_s - overlap_s
if step <= 0:
raise ValueError("chunk_length_s must be > overlap_s")
starts = []
t = 0.0
while t < duration:
starts.append(t)
t += step
chunks = []
for i, s in enumerate(starts):
chunk_path = make_temp_path(suffix=f"_chunk{i}.wav")
ffmpeg_extract_segment(src_wav, s, min(chunk_length_s, duration - s), chunk_path)
chunks.append(chunk_path)
return chunks
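# Hypothetical helper, not present in the original module: the header comment
# mentions merging, so this is a minimal sketch of rejoining chunk WAVs,
# assuming every chunk is 16 kHz mono pcm16 as produced by
# ffmpeg_extract_segment. The name merge_wav_chunks is illustrative. Note the
# concatenation is naive: overlapping windows from split_audio_to_chunks will
# repeat the overlapped audio.
def merge_wav_chunks(chunk_paths: List[str], dst: str) -> str:
    parts = []
    samplerate = None
    for path in chunk_paths:
        data, sr = sf.read(path, dtype="int16")
        if samplerate is None:
            samplerate = sr
        parts.append(data)
    if not parts:
        raise ValueError("no chunks to merge")
    # concatenate the int16 sample arrays and write a single PCM-16 WAV
    sf.write(dst, np.concatenate(parts), samplerate, subtype="PCM_16")
    return dst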
def split_audio_with_vad(
src_wav: str,
aggressiveness: int = 2,
frame_ms: int = 30,
padding_ms: int = 300,
) -> List[str]:
"""
Split audio using webrtcvad speech detection. Returns list of chunk file paths.
Falls back to fixed-window splitting if webrtcvad is not available or audio not 16k mono.
"""
    if not _HAS_VAD:
        return split_audio_to_chunks(src_wav)
    if frame_ms not in (10, 20, 30):
        raise ValueError("frame_ms must be 10, 20, or 30 for webrtcvad")
info = get_audio_info(src_wav)
if not info:
raise RuntimeError("Cannot read audio info for VAD split")
sr = int(info.get("samplerate", 0))
channels = int(info.get("channels", 0))
if sr != 16000 or channels != 1:
        # webrtcvad only accepts 8/16/32/48 kHz mono; this module standardizes
        # on 16 kHz mono, so fall back to fixed windows for anything else
return split_audio_to_chunks(src_wav)
    # read PCM samples as int16 (webrtcvad consumes raw 16-bit PCM bytes)
    data, _ = sf.read(src_wav, dtype="int16")
    if data.ndim > 1:
        # defensive: keep the first channel if the file is unexpectedly multichannel
        data = data[:, 0]
pcm_bytes = data.tobytes()
vad = webrtcvad.Vad(aggressiveness)
    frame_size = int(sr * frame_ms / 1000)  # samples per frame
    frame_bytes = frame_size * 2  # int16 samples are 2 bytes each
total_frames = (len(pcm_bytes) + frame_bytes - 1) // frame_bytes
speech_frames = []
for i in range(total_frames):
start = i * frame_bytes
end = start + frame_bytes
frame = pcm_bytes[start:end]
if len(frame) < frame_bytes:
# pad last frame
frame = frame.ljust(frame_bytes, b"\x00")
        try:
            is_speech = vad.is_speech(frame, sr)
        except Exception:
            # treat frames webrtcvad rejects as non-speech
            is_speech = False
        speech_frames.append(bool(is_speech))
# group contiguous speech frames into segments
segments = []
in_speech = False
seg_start = 0
for idx, val in enumerate(speech_frames):
if val and not in_speech:
in_speech = True
seg_start = idx
elif not val and in_speech:
in_speech = False
seg_end = idx - 1
segments.append((seg_start, seg_end))
if in_speech:
segments.append((seg_start, len(speech_frames) - 1))
    # merge segments whose silence gap is shorter than padding_ms
    merged = []
    pad_frames = int(padding_ms / frame_ms)  # maximum gap, in frames, to bridge
for seg in segments:
if not merged:
merged.append(seg)
continue
prev = merged[-1]
if seg[0] - prev[1] <= pad_frames:
merged[-1] = (prev[0], seg[1])
else:
merged.append(seg)
# convert frame indices to times and extract with ffmpeg
chunks = []
for i, (s_idx, e_idx) in enumerate(merged):
start_s = s_idx * frame_ms / 1000.0
dur = (e_idx - s_idx + 1) * frame_ms / 1000.0
chunk_path = make_temp_path(suffix=f"_vad_chunk{i}.wav")
ffmpeg_extract_segment(src_wav, start_s, dur, chunk_path)
chunks.append(chunk_path)
# If VAD found nothing, fallback to fixed windows
if not chunks:
return split_audio_to_chunks(src_wav)
return chunks
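if __name__ == "__main__":
    # Usage sketch (assumes a WAV path is passed on the command line):
    # prefer VAD-based splitting, which itself falls back to fixed 30 s
    # windows when webrtcvad is unavailable or the input is not 16 kHz mono.
    import sys
    for chunk in split_audio_with_vad(sys.argv[1]):
        print(chunk)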