Helw150 commited on
Commit
340ea34
·
1 Parent(s): 87930ea
Files changed (3) hide show
  1. utils/assets/silero_vad.onnx +3 -0
  2. utils/snac_utils.py +146 -0
  3. utils/vad.py +290 -0
utils/assets/silero_vad.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:591f853590d11ddde2f2a54f9e7ccecb2533a8af7716330e8adfa6f3849787a9
3
+ size 1807524
utils/snac_utils.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import time
3
+ import numpy as np
4
+
5
+
6
+ class SnacConfig:
7
+ audio_vocab_size = 4096
8
+ padded_vocab_size = 4160
9
+ end_of_audio = 4097
10
+
11
+
12
+ snac_config = SnacConfig()
13
+
14
+
15
+ def get_time_str():
16
+ time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
17
+ return time_str
18
+
19
+
20
+ def layershift(input_id, layer, stride=4160, shift=152000):
21
+ return input_id + shift + layer * stride
22
+
23
+
24
+ def generate_audio_data(snac_tokens, snacmodel, device=None):
25
+ audio = reconstruct_tensors(snac_tokens, device)
26
+ with torch.inference_mode():
27
+ audio_hat = snacmodel.decode(audio)
28
+ audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
29
+ audio_data = audio_data.astype(np.int16)
30
+ audio_data = audio_data.tobytes()
31
+ return audio_data
32
+
33
+
34
+ def get_snac(list_output, index, nums_generate):
35
+
36
+ snac = []
37
+ start = index
38
+ for i in range(nums_generate):
39
+ snac.append("#")
40
+ for j in range(7):
41
+ snac.append(list_output[j][start - nums_generate - 5 + j + i])
42
+ return snac
43
+
44
+
45
+ def reconscruct_snac(output_list):
46
+ if len(output_list) == 8:
47
+ output_list = output_list[:-1]
48
+ output = []
49
+ for i in range(7):
50
+ output_list[i] = output_list[i][i + 1 :]
51
+ for i in range(len(output_list[-1])):
52
+ output.append("#")
53
+ for j in range(7):
54
+ output.append(output_list[j][i])
55
+ return output
56
+
57
+
58
+ def reconstruct_tensors(flattened_output, device=None):
59
+ """Reconstructs the list of tensors from the flattened output."""
60
+
61
+ if device is None:
62
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
+
64
+ def count_elements_between_hashes(lst):
65
+ try:
66
+ # Find the index of the first '#'
67
+ first_index = lst.index("#")
68
+ # Find the index of the second '#' after the first
69
+ second_index = lst.index("#", first_index + 1)
70
+ # Count the elements between the two indices
71
+ return second_index - first_index - 1
72
+ except ValueError:
73
+ # Handle the case where there aren't enough '#' symbols
74
+ return "List does not contain two '#' symbols"
75
+
76
+ def remove_elements_before_hash(flattened_list):
77
+ try:
78
+ # Find the index of the first '#'
79
+ first_hash_index = flattened_list.index("#")
80
+ # Return the list starting from the first '#'
81
+ return flattened_list[first_hash_index:]
82
+ except ValueError:
83
+ # Handle the case where there is no '#'
84
+ return "List does not contain the symbol '#'"
85
+
86
+ def list_to_torch_tensor(tensor1):
87
+ # Convert the list to a torch tensor
88
+ tensor = torch.tensor(tensor1)
89
+ # Reshape the tensor to have size (1, n)
90
+ tensor = tensor.unsqueeze(0)
91
+ return tensor
92
+
93
+ flattened_output = remove_elements_before_hash(flattened_output)
94
+ codes = []
95
+ tensor1 = []
96
+ tensor2 = []
97
+ tensor3 = []
98
+ tensor4 = []
99
+
100
+ n_tensors = count_elements_between_hashes(flattened_output)
101
+ if n_tensors == 7:
102
+ for i in range(0, len(flattened_output), 8):
103
+
104
+ tensor1.append(flattened_output[i + 1])
105
+ tensor2.append(flattened_output[i + 2])
106
+ tensor3.append(flattened_output[i + 3])
107
+ tensor3.append(flattened_output[i + 4])
108
+
109
+ tensor2.append(flattened_output[i + 5])
110
+ tensor3.append(flattened_output[i + 6])
111
+ tensor3.append(flattened_output[i + 7])
112
+ codes = [
113
+ list_to_torch_tensor(tensor1).to(device),
114
+ list_to_torch_tensor(tensor2).to(device),
115
+ list_to_torch_tensor(tensor3).to(device),
116
+ ]
117
+
118
+ if n_tensors == 15:
119
+ for i in range(0, len(flattened_output), 16):
120
+
121
+ tensor1.append(flattened_output[i + 1])
122
+ tensor2.append(flattened_output[i + 2])
123
+ tensor3.append(flattened_output[i + 3])
124
+ tensor4.append(flattened_output[i + 4])
125
+ tensor4.append(flattened_output[i + 5])
126
+ tensor3.append(flattened_output[i + 6])
127
+ tensor4.append(flattened_output[i + 7])
128
+ tensor4.append(flattened_output[i + 8])
129
+
130
+ tensor2.append(flattened_output[i + 9])
131
+ tensor3.append(flattened_output[i + 10])
132
+ tensor4.append(flattened_output[i + 11])
133
+ tensor4.append(flattened_output[i + 12])
134
+ tensor3.append(flattened_output[i + 13])
135
+ tensor4.append(flattened_output[i + 14])
136
+ tensor4.append(flattened_output[i + 15])
137
+
138
+ codes = [
139
+ list_to_torch_tensor(tensor1).to(device),
140
+ list_to_torch_tensor(tensor2).to(device),
141
+ list_to_torch_tensor(tensor3).to(device),
142
+ list_to_torch_tensor(tensor4).to(device),
143
+ ]
144
+
145
+ return codes
146
+
utils/vad.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import bisect
2
+ import functools
3
+ import os
4
+ import warnings
5
+
6
+ from typing import List, NamedTuple, Optional
7
+
8
+ import numpy as np
9
+
10
+
11
+ # The code below is adapted from https://github.com/snakers4/silero-vad.
12
+ class VadOptions(NamedTuple):
13
+ """VAD options.
14
+
15
+ Attributes:
16
+ threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
17
+ probabilities ABOVE this value are considered as SPEECH. It is better to tune this
18
+ parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
19
+ min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out.
20
+ max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
21
+ than max_speech_duration_s will be split at the timestamp of the last silence that
22
+ lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
23
+ split aggressively just before max_speech_duration_s.
24
+ min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
25
+ before separating it
26
+ window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
27
+ WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
28
+ Values other than these may affect model performance!!
29
+ speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
30
+ """
31
+
32
+ threshold: float = 0.5
33
+ min_speech_duration_ms: int = 250
34
+ max_speech_duration_s: float = float("inf")
35
+ min_silence_duration_ms: int = 2000
36
+ window_size_samples: int = 1024
37
+ speech_pad_ms: int = 400
38
+
39
+
40
+ def get_speech_timestamps(
41
+ audio: np.ndarray,
42
+ vad_options: Optional[VadOptions] = None,
43
+ **kwargs,
44
+ ) -> List[dict]:
45
+ """This method is used for splitting long audios into speech chunks using silero VAD.
46
+
47
+ Args:
48
+ audio: One dimensional float array.
49
+ vad_options: Options for VAD processing.
50
+ kwargs: VAD options passed as keyword arguments for backward compatibility.
51
+
52
+ Returns:
53
+ List of dicts containing begin and end samples of each speech chunk.
54
+ """
55
+ if vad_options is None:
56
+ vad_options = VadOptions(**kwargs)
57
+
58
+ threshold = vad_options.threshold
59
+ min_speech_duration_ms = vad_options.min_speech_duration_ms
60
+ max_speech_duration_s = vad_options.max_speech_duration_s
61
+ min_silence_duration_ms = vad_options.min_silence_duration_ms
62
+ window_size_samples = vad_options.window_size_samples
63
+ speech_pad_ms = vad_options.speech_pad_ms
64
+
65
+ if window_size_samples not in [512, 1024, 1536]:
66
+ warnings.warn(
67
+ "Unusual window_size_samples! Supported window_size_samples:\n"
68
+ " - [512, 1024, 1536] for 16000 sampling_rate"
69
+ )
70
+
71
+ sampling_rate = 16000
72
+ min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
73
+ speech_pad_samples = sampling_rate * speech_pad_ms / 1000
74
+ max_speech_samples = (
75
+ sampling_rate * max_speech_duration_s
76
+ - window_size_samples
77
+ - 2 * speech_pad_samples
78
+ )
79
+ min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
80
+ min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
81
+
82
+ audio_length_samples = len(audio)
83
+
84
+ model = get_vad_model()
85
+ state = model.get_initial_state(batch_size=1)
86
+
87
+ speech_probs = []
88
+ for current_start_sample in range(0, audio_length_samples, window_size_samples):
89
+ chunk = audio[current_start_sample : current_start_sample + window_size_samples]
90
+ if len(chunk) < window_size_samples:
91
+ chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
92
+ speech_prob, state = model(chunk, state, sampling_rate)
93
+ speech_probs.append(speech_prob)
94
+
95
+ triggered = False
96
+ speeches = []
97
+ current_speech = {}
98
+ neg_threshold = threshold - 0.15
99
+
100
+ # to save potential segment end (and tolerate some silence)
101
+ temp_end = 0
102
+ # to save potential segment limits in case of maximum segment size reached
103
+ prev_end = next_start = 0
104
+
105
+ for i, speech_prob in enumerate(speech_probs):
106
+ if (speech_prob >= threshold) and temp_end:
107
+ temp_end = 0
108
+ if next_start < prev_end:
109
+ next_start = window_size_samples * i
110
+
111
+ if (speech_prob >= threshold) and not triggered:
112
+ triggered = True
113
+ current_speech["start"] = window_size_samples * i
114
+ continue
115
+
116
+ if (
117
+ triggered
118
+ and (window_size_samples * i) - current_speech["start"] > max_speech_samples
119
+ ):
120
+ if prev_end:
121
+ current_speech["end"] = prev_end
122
+ speeches.append(current_speech)
123
+ current_speech = {}
124
+ # previously reached silence (< neg_thres) and is still not speech (< thres)
125
+ if next_start < prev_end:
126
+ triggered = False
127
+ else:
128
+ current_speech["start"] = next_start
129
+ prev_end = next_start = temp_end = 0
130
+ else:
131
+ current_speech["end"] = window_size_samples * i
132
+ speeches.append(current_speech)
133
+ current_speech = {}
134
+ prev_end = next_start = temp_end = 0
135
+ triggered = False
136
+ continue
137
+
138
+ if (speech_prob < neg_threshold) and triggered:
139
+ if not temp_end:
140
+ temp_end = window_size_samples * i
141
+ # condition to avoid cutting in very short silence
142
+ if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
143
+ prev_end = temp_end
144
+ if (window_size_samples * i) - temp_end < min_silence_samples:
145
+ continue
146
+ else:
147
+ current_speech["end"] = temp_end
148
+ if (
149
+ current_speech["end"] - current_speech["start"]
150
+ ) > min_speech_samples:
151
+ speeches.append(current_speech)
152
+ current_speech = {}
153
+ prev_end = next_start = temp_end = 0
154
+ triggered = False
155
+ continue
156
+
157
+ if (
158
+ current_speech
159
+ and (audio_length_samples - current_speech["start"]) > min_speech_samples
160
+ ):
161
+ current_speech["end"] = audio_length_samples
162
+ speeches.append(current_speech)
163
+
164
+ for i, speech in enumerate(speeches):
165
+ if i == 0:
166
+ speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
167
+ if i != len(speeches) - 1:
168
+ silence_duration = speeches[i + 1]["start"] - speech["end"]
169
+ if silence_duration < 2 * speech_pad_samples:
170
+ speech["end"] += int(silence_duration // 2)
171
+ speeches[i + 1]["start"] = int(
172
+ max(0, speeches[i + 1]["start"] - silence_duration // 2)
173
+ )
174
+ else:
175
+ speech["end"] = int(
176
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
177
+ )
178
+ speeches[i + 1]["start"] = int(
179
+ max(0, speeches[i + 1]["start"] - speech_pad_samples)
180
+ )
181
+ else:
182
+ speech["end"] = int(
183
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
184
+ )
185
+
186
+ return speeches
187
+
188
+
189
+ def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
190
+ """Collects and concatenates audio chunks."""
191
+ if not chunks:
192
+ return np.array([], dtype=np.float32)
193
+
194
+ return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
195
+
196
+
197
+ class SpeechTimestampsMap:
198
+ """Helper class to restore original speech timestamps."""
199
+
200
+ def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
201
+ self.sampling_rate = sampling_rate
202
+ self.time_precision = time_precision
203
+ self.chunk_end_sample = []
204
+ self.total_silence_before = []
205
+
206
+ previous_end = 0
207
+ silent_samples = 0
208
+
209
+ for chunk in chunks:
210
+ silent_samples += chunk["start"] - previous_end
211
+ previous_end = chunk["end"]
212
+
213
+ self.chunk_end_sample.append(chunk["end"] - silent_samples)
214
+ self.total_silence_before.append(silent_samples / sampling_rate)
215
+
216
+ def get_original_time(
217
+ self,
218
+ time: float,
219
+ chunk_index: Optional[int] = None,
220
+ ) -> float:
221
+ if chunk_index is None:
222
+ chunk_index = self.get_chunk_index(time)
223
+
224
+ total_silence_before = self.total_silence_before[chunk_index]
225
+ return round(total_silence_before + time, self.time_precision)
226
+
227
+ def get_chunk_index(self, time: float) -> int:
228
+ sample = int(time * self.sampling_rate)
229
+ return min(
230
+ bisect.bisect(self.chunk_end_sample, sample),
231
+ len(self.chunk_end_sample) - 1,
232
+ )
233
+
234
+
235
+ @functools.lru_cache
236
+ def get_vad_model():
237
+ """Returns the VAD model instance."""
238
+ asset_dir = os.path.join(os.path.dirname(__file__), "assets")
239
+ path = os.path.join(asset_dir, "silero_vad.onnx")
240
+ return SileroVADModel(path)
241
+
242
+
243
+ class SileroVADModel:
244
+ def __init__(self, path):
245
+ try:
246
+ import onnxruntime
247
+ except ImportError as e:
248
+ raise RuntimeError(
249
+ "Applying the VAD filter requires the onnxruntime package"
250
+ ) from e
251
+
252
+ opts = onnxruntime.SessionOptions()
253
+ opts.inter_op_num_threads = 1
254
+ opts.intra_op_num_threads = 1
255
+ opts.log_severity_level = 4
256
+
257
+ self.session = onnxruntime.InferenceSession(
258
+ path,
259
+ providers=["CPUExecutionProvider"],
260
+ sess_options=opts,
261
+ )
262
+
263
+ def get_initial_state(self, batch_size: int):
264
+ h = np.zeros((2, batch_size, 64), dtype=np.float32)
265
+ c = np.zeros((2, batch_size, 64), dtype=np.float32)
266
+ return h, c
267
+
268
+ def __call__(self, x, state, sr: int):
269
+ if len(x.shape) == 1:
270
+ x = np.expand_dims(x, 0)
271
+ if len(x.shape) > 2:
272
+ raise ValueError(
273
+ f"Too many dimensions for input audio chunk {len(x.shape)}"
274
+ )
275
+ if sr / x.shape[1] > 31.25:
276
+ raise ValueError("Input audio chunk is too short")
277
+
278
+ h, c = state
279
+
280
+ ort_inputs = {
281
+ "input": x,
282
+ "h": h,
283
+ "c": c,
284
+ "sr": np.array(sr, dtype="int64"),
285
+ }
286
+
287
+ out, h, c = self.session.run(None, ort_inputs)
288
+ state = (h, c)
289
+
290
+ return out, state