yanyihan-xiaomi committed on
Commit
457ae0e
·
1 Parent(s): 5a5c0e1

Refactor VAD initialization and processing logic

Browse files

- Added global VAD instance and locking mechanism.
- Updated RealtimeVAD to use global VAD for processing.
- Implemented warmup procedure for VAD initialization.

Files changed (1) hide show
  1. webrtc_vad.py +54 -8
webrtc_vad.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from dataclasses import dataclass
2
  from typing import Callable, Generator, override
3
 
@@ -13,11 +14,31 @@ class VADEvent:
13
  full_audio: tuple[int, np.ndarray] | None = None
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  class RealtimeVAD:
17
  def __init__(
18
  self,
19
  src_sr: int = 24000,
20
- hop_size: int = 256,
21
  start_threshold: float = 0.8,
22
  end_threshold: float = 0.7,
23
  pad_start_s: float = 0.6,
@@ -26,15 +47,13 @@ class RealtimeVAD:
26
  ):
27
  self.src_sr = src_sr
28
  self.vad_sr = 16000
29
- self.hop_size = hop_size
30
  self.start_threshold = start_threshold
31
  self.end_threshold = end_threshold
32
  self.pad_start_s = pad_start_s
33
  self.min_positive_s = min_positive_s
34
  self.min_silence_s = min_silence_s
35
 
36
- self.vad_model = TenVad(hop_size=hop_size)
37
-
38
  self.vad_buffer = np.array([], dtype=np.int16)
39
  """
40
  VAD Buffer to store audio data for VAD processing
@@ -56,9 +75,6 @@ class RealtimeVAD:
56
  self.sum_positive_s = 0.0
57
  self.silence_start_s: float | None = None
58
 
59
- # Warmup
60
- self.vad_model.process(np.zeros(hop_size, dtype=np.int16))
61
-
62
  def process(self, audio_data: np.ndarray):
63
  if audio_data.ndim == 2:
64
  # FastRTC style [channels, samples]
@@ -77,7 +93,7 @@ class RealtimeVAD:
77
  vad_buffer_size = self.vad_buffer.shape[0]
78
 
79
  def process_chunk(chunk_offset_s: float, vad_chunk: np.ndarray):
80
- speech_prob, _ = self.vad_model.process(vad_chunk)
81
 
82
  hop_s = self.hop_size / self.vad_sr
83
 
@@ -133,6 +149,7 @@ class RealtimeVAD:
133
  self.sum_positive_s = 0.0
134
  self.silence_start_s = None
135
 
 
136
  for chunk_pos in range(0, vad_buffer_size - self.hop_size, self.hop_size):
137
  processed_samples = chunk_pos + self.hop_size
138
  chunk_offset_s = (self.vad_buffer_offset + chunk_pos) / self.vad_sr
@@ -143,6 +160,33 @@ class RealtimeVAD:
143
  self.vad_buffer_offset += processed_samples
144
 
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  type StreamerGenerator = Generator[fastrtc.tracks.EmitType, None, None]
147
  type StreamerFn = Callable[[tuple[int, np.ndarray], str], StreamerGenerator]
148
 
@@ -164,6 +208,8 @@ class VADStreamHandler(fastrtc.StreamHandler):
164
  self.realtime_vad = RealtimeVAD(src_sr=input_sample_rate)
165
  self.generator: StreamerGenerator | None = None
166
 
 
 
167
  @override
168
  def emit(self) -> fastrtc.tracks.EmitType:
169
  if self.generator is None:
 
1
+ import threading
2
  from dataclasses import dataclass
3
  from typing import Callable, Generator, override
4
 
 
14
  full_audio: tuple[int, np.ndarray] | None = None
15
 
16
 
17
+ global_ten_vad: TenVad | None = None
18
+ global_vad_lock = threading.Lock()
19
+
20
+
21
+ def global_vad_process(audio_data: np.ndarray) -> float:
22
+ """
23
+ Process audio data (hop_size=256) with global TenVad instance.
24
+
25
+ Returns:
26
+ speech probability.
27
+ """
28
+ global global_ten_vad
29
+
30
+ with global_vad_lock:
31
+ if global_ten_vad is None:
32
+ global_ten_vad = TenVad()
33
+
34
+ prob, _ = global_ten_vad.process(audio_data)
35
+ return prob
36
+
37
+
38
  class RealtimeVAD:
39
  def __init__(
40
  self,
41
  src_sr: int = 24000,
 
42
  start_threshold: float = 0.8,
43
  end_threshold: float = 0.7,
44
  pad_start_s: float = 0.6,
 
47
  ):
48
  self.src_sr = src_sr
49
  self.vad_sr = 16000
50
+ self.hop_size = 256
51
  self.start_threshold = start_threshold
52
  self.end_threshold = end_threshold
53
  self.pad_start_s = pad_start_s
54
  self.min_positive_s = min_positive_s
55
  self.min_silence_s = min_silence_s
56
 
 
 
57
  self.vad_buffer = np.array([], dtype=np.int16)
58
  """
59
  VAD Buffer to store audio data for VAD processing
 
75
  self.sum_positive_s = 0.0
76
  self.silence_start_s: float | None = None
77
 
 
 
 
78
  def process(self, audio_data: np.ndarray):
79
  if audio_data.ndim == 2:
80
  # FastRTC style [channels, samples]
 
93
  vad_buffer_size = self.vad_buffer.shape[0]
94
 
95
  def process_chunk(chunk_offset_s: float, vad_chunk: np.ndarray):
96
+ speech_prob = global_vad_process(vad_chunk)
97
 
98
  hop_s = self.hop_size / self.vad_sr
99
 
 
149
  self.sum_positive_s = 0.0
150
  self.silence_start_s = None
151
 
152
+ processed_samples = 0
153
  for chunk_pos in range(0, vad_buffer_size - self.hop_size, self.hop_size):
154
  processed_samples = chunk_pos + self.hop_size
155
  chunk_offset_s = (self.vad_buffer_offset + chunk_pos) / self.vad_sr
 
160
  self.vad_buffer_offset += processed_samples
161
 
162
 
163
+ def init_global_ten_vad(input_sample_rate: int = 24000):
164
+ """
165
+ Call this once at the start of the program to avoid latency on first use.
166
+ No-op if already initialized.
167
+ """
168
+ global global_ten_vad
169
+
170
+ require_warmup = False
171
+
172
+ with global_vad_lock:
173
+ if global_ten_vad is None:
174
+ global_ten_vad = TenVad()
175
+
176
+ require_warmup = True
177
+
178
+ if require_warmup:
179
+ print("[VAD] Initializing global TenVad...")
180
+
181
+ realtime_vad = RealtimeVAD(src_sr=input_sample_rate)
182
+
183
+ for _ in range(25): # Warmup with 1 second of silence
184
+ for _ in realtime_vad.process(np.zeros(960, dtype=np.int16)):
185
+ pass
186
+
187
+ print("[VAD] Global VAD initialized")
188
+
189
+
190
  type StreamerGenerator = Generator[fastrtc.tracks.EmitType, None, None]
191
  type StreamerFn = Callable[[tuple[int, np.ndarray], str], StreamerGenerator]
192
 
 
208
  self.realtime_vad = RealtimeVAD(src_sr=input_sample_rate)
209
  self.generator: StreamerGenerator | None = None
210
 
211
+ init_global_ten_vad()
212
+
213
  @override
214
  def emit(self) -> fastrtc.tracks.EmitType:
215
  if self.generator is None: