aadnk committed
Commit f288ceb
1 Parent(s): aa22372

Add support for selecting a VAD


The VAD (Voice Activity Detector) is used to detect the time segments that contain speech (more than 250 ms), and Whisper is only run on these segments. This prevents Whisper from getting stuck in a loop, producing the same sentence over and over (which usually happens after a long continuous stretch of no speech).

A secondary benefit is that any time-synchronization issues that are sometimes present will not carry over from one detected time segment to the next.

One slight issue, however, is that these time segments may be very short, which prevents Whisper from using the previous text as a prompt for context. To mitigate this somewhat, the detected speech segments are padded by 1 second before and 4 seconds after, and then merged if they overlap. Finally, the VAD model's speech threshold is set to 30% instead of the default 50%.
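
As a rough sketch of the padding-and-merge step described above (the segment values below are invented for illustration; the actual logic is in vad.py further down):

# Hypothetical VAD output, in seconds.
segments = [{'start': 3.0, 'end': 3.6}, {'start': 5.2, 'end': 9.8}, {'start': 40.0, 'end': 41.5}]

PADDING_LEFT = 1   # start each detected segment 1 second early
PADDING_RIGHT = 4  # end each detected segment 4 seconds late

# Pad every segment, clamping the start at zero.
padded = [{'start': max(0, s['start'] - PADDING_LEFT), 'end': s['end'] + PADDING_RIGHT}
          for s in segments]

# Merge segments that overlap after padding.
merged = []
for entry in padded:
    if merged and entry['start'] <= merged[-1]['end']:
        merged[-1]['end'] = max(merged[-1]['end'], entry['end'])
    else:
        merged.append(dict(entry))

print(merged)  # [{'start': 2.0, 'end': 13.8}, {'start': 39.0, 'end': 45.5}]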

Files changed (2)
  1. app.py +14 -3
  2. vad.py +185 -0
app.py CHANGED
@@ -15,6 +15,7 @@ import gradio as gr
 from download import ExceededMaximumDuration, downloadUrl
 
 from utils import slugify, write_srt, write_vtt
+from vad import VadTranscription
 
 # Limitations (set to -1 to disable)
 DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
@@ -49,9 +50,10 @@ model_cache = dict()
 
 class UI:
     def __init__(self, inputAudioMaxDuration):
+        self.vad_model = None
         self.inputAudioMaxDuration = inputAudioMaxDuration
 
-    def transcribeFile(self, modelName, languageName, urlData, uploadFile, microphoneData, task):
+    def transcribeFile(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad):
         try:
             source, sourceName = self.getSource(urlData, uploadFile, microphoneData)
 
@@ -66,7 +68,14 @@ class UI:
                model_cache[selectedModel] = model
 
            # The results
-           result = model.transcribe(source, language=selectedLanguage, task=task)
+           if (vad == 'silero-vad'):
+               # Use Silero VAD
+               if (self.vad_model is None):
+                   self.vad_model = VadTranscription()
+               result = self.vad_model.transcribe(source, lambda audio : model.transcribe(audio, language=selectedLanguage, task=task))
+           else:
+               # Default VAD
+               result = model.transcribe(source, language=selectedLanguage, task=task)
 
            text = result["text"]
 
@@ -154,7 +163,8 @@ def createUi(inputAudioMaxDuration, share=False, server_name: str = None):
    ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
    ui_description += " as well as speech translation and language identification. "
 
-   ui_description += "\n\n" + "Note: You can upload more audio (and even video) types by changing to All Files (*.*) in the file selector."
+   ui_description += "\n\n" + "Note: You can upload more audio (and even video) types by changing to All Files (*.*) in the file selector. For longer audio files (>10 minutes), "
+   ui_description += "it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
 
    if inputAudioMaxDuration > 0:
        ui_description += "\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s"
@@ -166,6 +176,7 @@ def createUi(inputAudioMaxDuration, share=False, server_name: str = None):
        gr.Audio(source="upload", type="filepath", label="Upload Audio"),
        gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
        gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
+       gr.Dropdown(choices=["none", "silero-vad"], label="VAD"),
    ], outputs=[
        gr.File(label="Download"),
        gr.Text(label="Transcription"),
vad.py ADDED
@@ -0,0 +1,185 @@
+from collections import Counter
+from dis import dis
+from typing import Any, Iterator, List, Dict
+
+from pprint import pprint
+import torch
+
+import ffmpeg
+import numpy as np
+
+SPEECH_TRESHOLD = 0.3
+MAX_SILENT_PERIOD = 10 # seconds
+
+SEGMENT_PADDING_LEFT = 1 # Start detected text segment early
+SEGMENT_PADDING_RIGHT = 4 # End detected segments late
+
+def load_audio(file: str, sample_rate: int = 16000,
+               start_time: str = None, duration: str = None):
+    """
+    Open an audio file and read as mono waveform, resampling as necessary
+
+    Parameters
+    ----------
+    file: str
+        The audio file to open
+
+    sample_rate: int
+        The sample rate to resample the audio if necessary
+
+    start_time: str
+        The start time, using the standard FFMPEG time duration syntax, or None to disable.
+
+    duration: str
+        The duration, using the standard FFMPEG time duration syntax, or None to disable.
+
+    Returns
+    -------
+    A NumPy array containing the audio waveform, in float32 dtype.
+    """
+    try:
+        inputArgs = {'threads': 0}
+
+        if (start_time is not None):
+            inputArgs['ss'] = start_time
+        if (duration is not None):
+            inputArgs['t'] = duration
+
+        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
+        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
+        out, _ = (
+            ffmpeg.input(file, **inputArgs)
+            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
+            .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
+        )
+    except ffmpeg.Error as e:
+        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}")
+
+    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+
+class VadTranscription:
+    def __init__(self):
+        self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
+
+        (self.get_speech_timestamps, _, _, _, _) = utils
+
+    def transcribe(self, audio: str, whisperCallable):
+        SAMPLING_RATE = 16000
+        wav = load_audio(audio, sample_rate=SAMPLING_RATE)
+
+        # Get speech timestamps from the full audio file
+        sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=SAMPLING_RATE, threshold=SPEECH_TRESHOLD)
+        seconds_timestamps = self.convert_seconds(sample_timestamps, sampling_rate=SAMPLING_RATE)
+
+        padded = self.pad_timestamps(seconds_timestamps, SEGMENT_PADDING_LEFT, SEGMENT_PADDING_RIGHT)
+        merged = self.merge_timestamps(padded, MAX_SILENT_PERIOD)
+
+        print("Timestamps:")
+        pprint(merged)
+
+        result = {
+            'text': "",
+            'segments': [],
+            'language': ""
+        }
+        languageCounter = Counter()
+
+        # For each time segment, run whisper
+        for segment in merged:
+            segment_start = segment['start']
+            segment_duration = segment['end'] - segment_start
+
+            segment_audio = load_audio(audio, sample_rate=SAMPLING_RATE, start_time = str(segment_start) + "s", duration = str(segment_duration) + "s")
+
+            print("Running whisper on " + str(segment_start) + ", duration: " + str(segment_duration))
+            segment_result = whisperCallable(segment_audio)
+            adjusted_segments = self.adjust_whisper_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
+
+            # Append to output
+            result['text'] += segment_result['text']
+            result['segments'].extend(adjusted_segments)
+
+            # Increment detected language
+            languageCounter[segment_result['language']] += 1
+
+        if len(languageCounter) > 0:
+            result['language'] = languageCounter.most_common(1)[0][0]
+
+        return result
+
+    def adjust_whisper_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
+        result = []
+
+        for segment in segments:
+            segment_start = float(segment['start'])
+            segment_end = float(segment['end'])
+
+            # Filter segments?
+            if (max_source_time is not None):
+                if (segment_start > max_source_time):
+                    continue
+                segment_end = min(max_source_time, segment_end)
+
+            new_segment = segment.copy()
+
+            # Add to start and end
+            new_segment['start'] = segment_start + adjust_seconds
+            new_segment['end'] = segment_end + adjust_seconds
+            result.append(new_segment)
+        return result
+
+    def pad_timestamps(self, timestamps: List[Dict[str, Any]], padding_left: float, padding_right: float):
+        result = []
+
+        for entry in timestamps:
+            segment_start = entry['start']
+            segment_end = entry['end']
+
+            if padding_left is not None:
+                segment_start = max(0, segment_start - padding_left)
+            if padding_right is not None:
+                segment_end = segment_end + padding_right
+
+            result.append({ 'start': segment_start, 'end': segment_end })
+
+        return result
+
+    def merge_timestamps(self, timestamps: List[Dict[str, Any]], max_distance: float):
+        result = []
+        current_entry = None
+
+        for entry in timestamps:
+            if current_entry is None:
+                current_entry = entry
+                continue
+
+            # Get distance to the previous entry
+            distance = entry['start'] - current_entry['end']
+
+            if distance <= max_distance:
+                # Merge
+                current_entry['end'] = entry['end']
+            else:
+                # Output current entry
+                result.append(current_entry)
+                current_entry = entry
+
+        # Add final entry
+        if current_entry is not None:
+            result.append(current_entry)
+
+        return result
+
+    def convert_seconds(self, timestamps: List[Dict[str, Any]], sampling_rate: int):
+        result = []
+
+        for entry in timestamps:
+            start = entry['start']
+            end = entry['end']
+
+            result.append({
+                'start': start / sampling_rate,
+                'end': end / sampling_rate
+            })
+        return result
+
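
For reference, a minimal usage sketch of the new VadTranscription class (the file name and Whisper model size here are just examples, not part of the commit):

import whisper
from vad import VadTranscription

# Load a Whisper model and the Silero VAD wrapper added in this commit.
model = whisper.load_model("medium")   # example model size
vad = VadTranscription()

# The VAD splits the file into speech segments, runs the callable on each one,
# and stitches the per-segment results back together with adjusted timestamps.
result = vad.transcribe("example.mp3",
                        lambda audio: model.transcribe(audio, language=None, task="transcribe"))

print(result["language"])
print(result["text"])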