aadnk committed
Commit bbbf06e
1 Parent(s): 0d8d833

Add a simple "VAD" alternative - periodic VAD


This is a very simple alternative to using Silero VAD: every
5-minute window (0:00 - 5:00, 5:00 - 10:00, etc.) is marked as a
"speech" segment, and Whisper is then run on each segment
individually.

The upside is that all potential dialogue is sent to Whisper for
transcription. The downside is that Whisper is more likely to get
stuck in an infinite sentence loop than with Silero VAD, although
any such loop is confined to a single 5-minute segment.

Creating an artificial break every 5 minutes may also accidentally
split a sentence, causing that sentence (or a word) to be
transcribed incorrectly.
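
For concreteness, here is a minimal sketch of the windowing rule described above (the helper name is hypothetical; the actual implementation added by this commit is VadPeriodicTranscription.get_transcribe_timestamps in vad.py):

    def periodic_windows(duration_seconds: float, period: float = 5 * 60):
        # Split [0, duration) into consecutive fixed-size windows (hypothetical helper,
        # mirroring the logic added to vad.py by this commit).
        windows = []
        start = 0.0
        while start < duration_seconds:
            end = min(start + period, duration_seconds)
            windows.append({'start': start, 'end': end})
            start = end
        return windows

    # A 12-minute (720 s) file yields three segments:
    # [{'start': 0.0, 'end': 300.0}, {'start': 300.0, 'end': 600.0}, {'start': 600.0, 'end': 720.0}]
    print(periodic_windows(720.0))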

Files changed (3)
  1. app.py +13 -5
  2. tests/vad_test.py +66 -0
  3. vad.py +119 -58
app.py CHANGED
@@ -15,7 +15,7 @@ import gradio as gr
 from download import ExceededMaximumDuration, downloadUrl
 
 from utils import slugify, write_srt, write_vtt
-from vad import VadTranscription
+from vad import VadPeriodicTranscription, VadSileroTranscription
 
 # Limitations (set to -1 to disable)
 DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
@@ -67,15 +67,23 @@ class UI:
             model = whisper.load_model(selectedModel)
             model_cache[selectedModel] = model
 
+        # Callable for processing an audio file
+        whisperCallable = lambda audio : model.transcribe(audio, language=selectedLanguage, task=task)
+
         # The results
         if (vad == 'silero-vad'):
             # Use Silero VAD
             if (self.vad_model is None):
-                self.vad_model = VadTranscription()
-            result = self.vad_model.transcribe(source, lambda audio : model.transcribe(audio, language=selectedLanguage, task=task))
+                self.vad_model = VadSileroTranscription()
+            result = self.vad_model.transcribe(source, whisperCallable)
+        elif (vad == 'periodic-vad'):
+            # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
+            # it may create a break in the middle of a sentence, causing some artifacts.
+            periodic_vad = VadPeriodicTranscription(periodic_duration=60 * 5)
+            result = periodic_vad.transcribe(source, whisperCallable)
         else:
             # Default VAD
-            result = model.transcribe(source, language=selectedLanguage, task=task)
+            result = whisperCallable(source)
 
         text = result["text"]
 
@@ -176,7 +184,7 @@ def createUi(inputAudioMaxDuration, share=False, server_name: str = None):
         gr.Audio(source="upload", type="filepath", label="Upload Audio"),
         gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
         gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
-        gr.Dropdown(choices=["none", "silero-vad"], label="VAD"),
+        gr.Dropdown(choices=["none", "silero-vad", "periodic-vad"], label="VAD"),
     ], outputs=[
         gr.File(label="Download"),
         gr.Text(label="Transcription"),
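
For reference, a minimal sketch of driving the new classes directly from Python, mirroring what app.py now does; the model size and file name below are placeholders, not part of the commit:

    import whisper
    from vad import VadPeriodicTranscription

    model = whisper.load_model("medium")   # placeholder model size
    whisperCallable = lambda audio : model.transcribe(audio, task="transcribe")

    # Mark every 5-minute window as speech and transcribe each window individually
    periodic_vad = VadPeriodicTranscription(periodic_duration=60 * 5)
    result = periodic_vad.transcribe("example.mp3", whisperCallable)   # placeholder file
    print(result["text"])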
tests/vad_test.py ADDED
@@ -0,0 +1,66 @@
+import pprint
+import unittest
+import numpy as np
+import sys
+
+sys.path.append('../whisper-webui')
+
+from vad import AbstractTranscription, VadSileroTranscription
+
+class TestVad(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super(TestVad, self).__init__(*args, **kwargs)
+        self.transcribe_calls = []
+
+    def test_transcript(self):
+        mock = MockVadTranscription()
+
+        self.transcribe_calls.clear()
+        result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))
+
+        self.assertListEqual(self.transcribe_calls, [
+            [30, 30],
+            [100, 100]
+        ])
+
+        self.assertListEqual(result['segments'],
+            [{'end': 50.0, 'start': 40.0, 'text': 'Hello world'},
+             {'end': 120.0, 'start': 110.0, 'text': 'Hello world'}]
+        )
+
+    def transcribe_segments(self, segment):
+        self.transcribe_calls.append(segment.tolist())
+
+        # Dummy text
+        return {
+            'text': "Hello world ",
+            'segments': [
+                {
+                    "start": 10.0,
+                    "end": 20.0,
+                    "text": "Hello world "
+                }
+            ],
+            'language': ""
+        }
+
+class MockVadTranscription(AbstractTranscription):
+    def __init__(self):
+        super().__init__()
+
+    def get_audio_segment(self, str, start_time: str = None, duration: str = None):
+        start_time_seconds = float(start_time.removesuffix("s"))
+        duration_seconds = float(duration.removesuffix("s"))
+
+        # For mocking, this just returns a simple numpy array
+        return np.array([start_time_seconds, duration_seconds], dtype=np.float64)
+
+    def get_transcribe_timestamps(self, audio: str):
+        result = []
+
+        result.append( { 'start': 30, 'end': 60 } )
+        result.append( { 'start': 100, 'end': 200 } )
+        return result
+
+if __name__ == '__main__':
+    unittest.main()
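
The test drives AbstractTranscription.transcribe through MockVadTranscription, asserting both on the audio segments handed to the transcription callback and on the timestamps of the merged result. Given the sys.path.append('../whisper-webui') above, it presumably expects the repository checkout to be named whisper-webui; running python tests/vad_test.py from the repository root should then work.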
vad.py CHANGED
@@ -1,6 +1,7 @@
+from abc import ABC, abstractmethod
 from collections import Counter
 from dis import dis
-from typing import Any, Iterator, List, Dict
+from typing import Any, Callable, Iterator, List, Dict, Union
 
 from pprint import pprint
 import torch
@@ -8,71 +9,45 @@ import torch
 import ffmpeg
 import numpy as np
 
+# Defaults
 SPEECH_TRESHOLD = 0.3
 MAX_SILENT_PERIOD = 10 # seconds
-
 SEGMENT_PADDING_LEFT = 1 # Start detected text segment early
 SEGMENT_PADDING_RIGHT = 4 # End detected segments late
 
-def load_audio(file: str, sample_rate: int = 16000,
-               start_time: str = None, duration: str = None):
-    """
-    Open an audio file and read as mono waveform, resampling as necessary
-
-    Parameters
-    ----------
-    file: str
-        The audio file to open
-
-    sr: int
-        The sample rate to resample the audio if necessary
-
-    start_time: str
-        The start time, using the standard FFMPEG time duration syntax, or None to disable.
-
-    duration: str
-        The duration, using the standard FFMPEG time duration syntax, or None to disable.
-
-    Returns
-    -------
-    A NumPy array containing the audio waveform, in float32 dtype.
-    """
-    try:
-        inputArgs = {'threads': 0}
-
-        if (start_time is not None):
-            inputArgs['ss'] = start_time
-        if (duration is not None):
-            inputArgs['t'] = duration
-
-        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
-        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
-        out, _ = (
-            ffmpeg.input(file, **inputArgs)
-            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
-            .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
-        )
-    except ffmpeg.Error as e:
-        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}")
-
-    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
-
-class VadTranscription:
-    def __init__(self):
-        self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
-
-        (self.get_speech_timestamps, _, _, _, _) = utils
-
-    def transcribe(self, audio: str, whisperCallable):
-        SAMPLING_RATE = 16000
-        wav = load_audio(audio, sample_rate=SAMPLING_RATE)
-
+class AbstractTranscription(ABC):
+    def __init__(self, segment_padding_left: int = None, segment_padding_right = None, max_silent_period: int = None):
+        self.sampling_rate = 16000
+        self.segment_padding_left = segment_padding_left
+        self.segment_padding_right = segment_padding_right
+        self.max_silent_period = max_silent_period
+
+    def get_audio_segment(self, str, start_time: str = None, duration: str = None):
+        return load_audio(str, self.sampling_rate, start_time, duration)
+
+    @abstractmethod
+    def get_transcribe_timestamps(self, audio: str):
+        """
+        Get the start and end timestamps of the sections that should be transcribed by this VAD method.
+
+        Parameters
+        ----------
+        audio: str
+            The audio file.
+
+        Returns
+        -------
+        A list of start and end timestamps, in fractional seconds.
+        """
+        return
+
+    def transcribe(self, audio: str, whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor]], dict[str, Union[dict, Any]]]):
         # get speech timestamps from full audio file
-        sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=SAMPLING_RATE, threshold=SPEECH_TRESHOLD)
-        seconds_timestamps = self.convert_seconds(sample_timestamps, sampling_rate=SAMPLING_RATE)
+        seconds_timestamps = self.get_transcribe_timestamps(audio)
 
-        padded = self.pad_timestamps(seconds_timestamps, SEGMENT_PADDING_LEFT, SEGMENT_PADDING_RIGHT)
-        merged = self.merge_timestamps(padded, MAX_SILENT_PERIOD)
+        padded = self.pad_timestamps(seconds_timestamps, self.segment_padding_left, self.segment_padding_right)
+        merged = self.merge_timestamps(padded, self.max_silent_period)
 
         print("Timestamps:")
         pprint(merged)
@@ -89,7 +64,7 @@ class VadTranscription:
             segment_start = segment['start']
            segment_duration = segment['end'] - segment_start
 
-            segment_audio = load_audio(audio, sample_rate=SAMPLING_RATE, start_time = str(segment_start) + "s", duration = str(segment_duration) + "s")
+            segment_audio = self.get_audio_segment(audio, start_time = str(segment_start) + "s", duration = str(segment_duration) + "s")
 
             print("Running whisper on " + str(segment_start) + ", duration: " + str(segment_duration))
             segment_result = whisperCallable(segment_audio)
@@ -145,6 +120,9 @@ class VadTranscription:
         return result
 
     def merge_timestamps(self, timestamps: List[Dict[str, Any]], max_distance: float):
+        if max_distance is None:
+            return timestamps
+
         result = []
         current_entry = None
 
@@ -170,7 +148,7 @@ class VadTranscription:
 
         return result
 
-    def convert_seconds(self, timestamps: List[Dict[str, Any]], sampling_rate: int):
+    def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
         result = []
 
         for entry in timestamps:
@@ -178,8 +156,91 @@ class VadTranscription:
            end = entry['end']
 
            result.append({
-                'start': start / sampling_rate,
-                'end': end / sampling_rate
+                'start': start * factor,
+                'end': end * factor
            })
        return result
-
+
+class VadSileroTranscription(AbstractTranscription):
+    def __init__(self):
+        super().__init__(SEGMENT_PADDING_LEFT, SEGMENT_PADDING_RIGHT, MAX_SILENT_PERIOD)
+        self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
+
+        (self.get_speech_timestamps, _, _, _, _) = utils
+
+    def get_transcribe_timestamps(self, audio: str):
+        wav = self.get_audio_segment(audio)
+
+        sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_TRESHOLD)
+        seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
+
+        return seconds_timestamps
+
+# A very simple VAD that just marks every N seconds as speech
+class VadPeriodicTranscription(AbstractTranscription):
+    def __init__(self, periodic_duration: int):
+        super().__init__()
+        self.periodic_duration = periodic_duration
+
+    def get_transcribe_timestamps(self, audio: str):
+        # Get duration in seconds
+        audio_duration = float(ffmpeg.probe(audio)["format"]["duration"])
+        result = []
+
+        # Generate a timestamp every N seconds
+        start_timestamp = 0
+
+        while (start_timestamp < audio_duration):
+            end_timestamp = min(start_timestamp + self.periodic_duration, audio_duration)
+            segment_duration = end_timestamp - start_timestamp
+
+            # Minimum duration is 1 second
+            if (segment_duration >= 1):
+                result.append( { 'start': start_timestamp, 'end': end_timestamp } )
+
+            start_timestamp = end_timestamp
+
+        return result
+
+def load_audio(file: str, sample_rate: int = 16000,
+               start_time: str = None, duration: str = None):
+    """
+    Open an audio file and read as mono waveform, resampling as necessary
+
+    Parameters
+    ----------
+    file: str
+        The audio file to open
+
+    sr: int
+        The sample rate to resample the audio if necessary
+
+    start_time: str
+        The start time, using the standard FFMPEG time duration syntax, or None to disable.
+
+    duration: str
+        The duration, using the standard FFMPEG time duration syntax, or None to disable.
+
+    Returns
+    -------
+    A NumPy array containing the audio waveform, in float32 dtype.
+    """
+    try:
+        inputArgs = {'threads': 0}
+
+        if (start_time is not None):
+            inputArgs['ss'] = start_time
+        if (duration is not None):
+            inputArgs['t'] = duration
+
+        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
+        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
+        out, _ = (
+            ffmpeg.input(file, **inputArgs)
+            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
+            .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
+        )
+    except ffmpeg.Error as e:
+        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}")
+
+    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
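
With this refactoring, a new VAD strategy only needs to subclass AbstractTranscription and implement get_transcribe_timestamps, returning start/end times in fractional seconds. A hypothetical sketch, not part of this commit:

    from vad import AbstractTranscription

    class VadFixedWindowsTranscription(AbstractTranscription):
        # Hypothetical example: transcribe a fixed, caller-supplied list of (start, end) windows.
        def __init__(self, windows):
            super().__init__()
            self.windows = windows

        def get_transcribe_timestamps(self, audio: str):
            # Timestamps are in fractional seconds, as expected by AbstractTranscription.transcribe
            return [{'start': float(start), 'end': float(end)} for start, end in self.windows]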