Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
457ae0e
1
Parent(s):
5a5c0e1
Refactor VAD initialization and processing logic
- Added global VAD instance and locking mechanism.
- Updated RealtimeVAD to use global VAD for processing.
- Implemented warmup procedure for VAD initialization.
- webrtc_vad.py +54 -8
webrtc_vad.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
from typing import Callable, Generator, override
|
| 3 |
|
|
@@ -13,11 +14,31 @@ class VADEvent:
|
|
| 13 |
full_audio: tuple[int, np.ndarray] | None = None
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
class RealtimeVAD:
|
| 17 |
def __init__(
|
| 18 |
self,
|
| 19 |
src_sr: int = 24000,
|
| 20 |
-
hop_size: int = 256,
|
| 21 |
start_threshold: float = 0.8,
|
| 22 |
end_threshold: float = 0.7,
|
| 23 |
pad_start_s: float = 0.6,
|
|
@@ -26,15 +47,13 @@ class RealtimeVAD:
|
|
| 26 |
):
|
| 27 |
self.src_sr = src_sr
|
| 28 |
self.vad_sr = 16000
|
| 29 |
-
self.hop_size =
|
| 30 |
self.start_threshold = start_threshold
|
| 31 |
self.end_threshold = end_threshold
|
| 32 |
self.pad_start_s = pad_start_s
|
| 33 |
self.min_positive_s = min_positive_s
|
| 34 |
self.min_silence_s = min_silence_s
|
| 35 |
|
| 36 |
-
self.vad_model = TenVad(hop_size=hop_size)
|
| 37 |
-
|
| 38 |
self.vad_buffer = np.array([], dtype=np.int16)
|
| 39 |
"""
|
| 40 |
VAD Buffer to store audio data for VAD processing
|
|
@@ -56,9 +75,6 @@ class RealtimeVAD:
|
|
| 56 |
self.sum_positive_s = 0.0
|
| 57 |
self.silence_start_s: float | None = None
|
| 58 |
|
| 59 |
-
# Warmup
|
| 60 |
-
self.vad_model.process(np.zeros(hop_size, dtype=np.int16))
|
| 61 |
-
|
| 62 |
def process(self, audio_data: np.ndarray):
|
| 63 |
if audio_data.ndim == 2:
|
| 64 |
# FastRTC style [channels, samples]
|
|
@@ -77,7 +93,7 @@ class RealtimeVAD:
|
|
| 77 |
vad_buffer_size = self.vad_buffer.shape[0]
|
| 78 |
|
| 79 |
def process_chunk(chunk_offset_s: float, vad_chunk: np.ndarray):
|
| 80 |
-
speech_prob
|
| 81 |
|
| 82 |
hop_s = self.hop_size / self.vad_sr
|
| 83 |
|
|
@@ -133,6 +149,7 @@ class RealtimeVAD:
|
|
| 133 |
self.sum_positive_s = 0.0
|
| 134 |
self.silence_start_s = None
|
| 135 |
|
|
|
|
| 136 |
for chunk_pos in range(0, vad_buffer_size - self.hop_size, self.hop_size):
|
| 137 |
processed_samples = chunk_pos + self.hop_size
|
| 138 |
chunk_offset_s = (self.vad_buffer_offset + chunk_pos) / self.vad_sr
|
|
@@ -143,6 +160,33 @@ class RealtimeVAD:
|
|
| 143 |
self.vad_buffer_offset += processed_samples
|
| 144 |
|
| 145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
type StreamerGenerator = Generator[fastrtc.tracks.EmitType, None, None]
|
| 147 |
type StreamerFn = Callable[[tuple[int, np.ndarray], str], StreamerGenerator]
|
| 148 |
|
|
@@ -164,6 +208,8 @@ class VADStreamHandler(fastrtc.StreamHandler):
|
|
| 164 |
self.realtime_vad = RealtimeVAD(src_sr=input_sample_rate)
|
| 165 |
self.generator: StreamerGenerator | None = None
|
| 166 |
|
|
|
|
|
|
|
| 167 |
@override
|
| 168 |
def emit(self) -> fastrtc.tracks.EmitType:
|
| 169 |
if self.generator is None:
|
|
|
|
| 1 |
+
import threading
|
| 2 |
from dataclasses import dataclass
|
| 3 |
from typing import Callable, Generator, override
|
| 4 |
|
|
|
|
| 14 |
full_audio: tuple[int, np.ndarray] | None = None
|
| 15 |
|
| 16 |
|
| 17 |
+
global_ten_vad: TenVad | None = None
|
| 18 |
+
global_vad_lock = threading.Lock()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def global_vad_process(audio_data: np.ndarray) -> float:
    """
    Run one VAD inference on a single audio hop using the shared global
    TenVad instance, creating it lazily on first use.

    Args:
        audio_data: int16 mono samples, one hop (hop_size=256) long.

    Returns:
        speech probability.
    """
    global global_ten_vad

    # Hold the lock for the whole call, not just lazy construction: the
    # global TenVad is a single stateful instance shared by every stream,
    # so concurrent process() calls must be serialized, not only init.
    with global_vad_lock:
        if global_ten_vad is None:
            global_ten_vad = TenVad()

        prob, _ = global_ten_vad.process(audio_data)

    return prob
|
| 36 |
+
|
| 37 |
+
|
| 38 |
class RealtimeVAD:
|
| 39 |
def __init__(
|
| 40 |
self,
|
| 41 |
src_sr: int = 24000,
|
|
|
|
| 42 |
start_threshold: float = 0.8,
|
| 43 |
end_threshold: float = 0.7,
|
| 44 |
pad_start_s: float = 0.6,
|
|
|
|
| 47 |
):
|
| 48 |
self.src_sr = src_sr
|
| 49 |
self.vad_sr = 16000
|
| 50 |
+
self.hop_size = 256
|
| 51 |
self.start_threshold = start_threshold
|
| 52 |
self.end_threshold = end_threshold
|
| 53 |
self.pad_start_s = pad_start_s
|
| 54 |
self.min_positive_s = min_positive_s
|
| 55 |
self.min_silence_s = min_silence_s
|
| 56 |
|
|
|
|
|
|
|
| 57 |
self.vad_buffer = np.array([], dtype=np.int16)
|
| 58 |
"""
|
| 59 |
VAD Buffer to store audio data for VAD processing
|
|
|
|
| 75 |
self.sum_positive_s = 0.0
|
| 76 |
self.silence_start_s: float | None = None
|
| 77 |
|
|
|
|
|
|
|
|
|
|
| 78 |
def process(self, audio_data: np.ndarray):
|
| 79 |
if audio_data.ndim == 2:
|
| 80 |
# FastRTC style [channels, samples]
|
|
|
|
| 93 |
vad_buffer_size = self.vad_buffer.shape[0]
|
| 94 |
|
| 95 |
def process_chunk(chunk_offset_s: float, vad_chunk: np.ndarray):
|
| 96 |
+
speech_prob = global_vad_process(vad_chunk)
|
| 97 |
|
| 98 |
hop_s = self.hop_size / self.vad_sr
|
| 99 |
|
|
|
|
| 149 |
self.sum_positive_s = 0.0
|
| 150 |
self.silence_start_s = None
|
| 151 |
|
| 152 |
+
processed_samples = 0
|
| 153 |
for chunk_pos in range(0, vad_buffer_size - self.hop_size, self.hop_size):
|
| 154 |
processed_samples = chunk_pos + self.hop_size
|
| 155 |
chunk_offset_s = (self.vad_buffer_offset + chunk_pos) / self.vad_sr
|
|
|
|
| 160 |
self.vad_buffer_offset += processed_samples
|
| 161 |
|
| 162 |
|
| 163 |
+
def init_global_ten_vad(input_sample_rate: int = 24000):
    """
    Eagerly create and warm up the shared global TenVad instance.

    Call this once at the start of the program to avoid latency on first use.
    No-op if already initialized.
    """
    global global_ten_vad

    with global_vad_lock:
        if global_ten_vad is not None:
            # Someone already initialized it; nothing to warm up.
            return
        global_ten_vad = TenVad()

    # Warmup happens outside the lock: RealtimeVAD.process() re-acquires
    # global_vad_lock internally, so keeping it here would deadlock.
    print("[VAD] Initializing global TenVad...")

    warmup_vad = RealtimeVAD(src_sr=input_sample_rate)
    silence = np.zeros(960, dtype=np.int16)
    # 25 chunks x 960 samples @ 24 kHz ~= 1 second of silence through the
    # full pipeline, so the first real inference is not the slow one.
    for _ in range(25):
        for _ in warmup_vad.process(silence):
            pass

    print("[VAD] Global VAD initialized")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
type StreamerGenerator = Generator[fastrtc.tracks.EmitType, None, None]
|
| 191 |
type StreamerFn = Callable[[tuple[int, np.ndarray], str], StreamerGenerator]
|
| 192 |
|
|
|
|
| 208 |
self.realtime_vad = RealtimeVAD(src_sr=input_sample_rate)
|
| 209 |
self.generator: StreamerGenerator | None = None
|
| 210 |
|
| 211 |
+
init_global_ten_vad()
|
| 212 |
+
|
| 213 |
@override
|
| 214 |
def emit(self) -> fastrtc.tracks.EmitType:
|
| 215 |
if self.generator is None:
|