Spaces:
Sleeping
Add support for selecting a VAD
Browse filesThe VAD (Voice Activity Detector) is used to detect time segments
where there is speech (more than 250ms), and only run whisper on
these segments. This prevents Whisper getting stuck in a loop, producing
the same sentence over and over (which usually happens after a long
continuous sequence of no speech).
A secondary benefit is that the time synchronization issues that sometimes are
present, will not carry over each detected time segment.
One slight issue, however, may be if these time sequences are very short,
and prevent Whisper from using previous text as prompt for context. To
mitigate this somewhat, the detected time segments with speech are
padded by 1 second before, and 4 seconds after, and then merged
if they overlap. Finally, the VAO model's threshold is set to 30%
instead of 50%.
@@ -15,6 +15,7 @@ import gradio as gr
|
|
15 |
from download import ExceededMaximumDuration, downloadUrl
|
16 |
|
17 |
from utils import slugify, write_srt, write_vtt
|
|
|
18 |
|
19 |
# Limitations (set to -1 to disable)
|
20 |
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
|
@@ -49,9 +50,10 @@ model_cache = dict()
|
|
49 |
|
50 |
class UI:
|
51 |
def __init__(self, inputAudioMaxDuration):
|
|
|
52 |
self.inputAudioMaxDuration = inputAudioMaxDuration
|
53 |
|
54 |
-
def transcribeFile(self, modelName, languageName, urlData, uploadFile, microphoneData, task):
|
55 |
try:
|
56 |
source, sourceName = self.getSource(urlData, uploadFile, microphoneData)
|
57 |
|
@@ -66,7 +68,14 @@ class UI:
|
|
66 |
model_cache[selectedModel] = model
|
67 |
|
68 |
# The results
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
text = result["text"]
|
72 |
|
@@ -154,7 +163,8 @@ def createUi(inputAudioMaxDuration, share=False, server_name: str = None):
|
|
154 |
ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
|
155 |
ui_description += " as well as speech translation and language identification. "
|
156 |
|
157 |
-
ui_description += "\n\n" + "Note: You can upload more audio (and even video) types by changing to All Files (*.*) in the file selector."
|
|
|
158 |
|
159 |
if inputAudioMaxDuration > 0:
|
160 |
ui_description += "\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s"
|
@@ -166,6 +176,7 @@ def createUi(inputAudioMaxDuration, share=False, server_name: str = None):
|
|
166 |
gr.Audio(source="upload", type="filepath", label="Upload Audio"),
|
167 |
gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
|
168 |
gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
|
|
|
169 |
], outputs=[
|
170 |
gr.File(label="Download"),
|
171 |
gr.Text(label="Transcription"),
|
|
|
15 |
from download import ExceededMaximumDuration, downloadUrl
|
16 |
|
17 |
from utils import slugify, write_srt, write_vtt
|
18 |
+
from vad import VadTranscription
|
19 |
|
20 |
# Limitations (set to -1 to disable)
|
21 |
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
|
|
|
50 |
|
51 |
class UI:
|
52 |
def __init__(self, inputAudioMaxDuration):
|
53 |
+
self.vad_model = None
|
54 |
self.inputAudioMaxDuration = inputAudioMaxDuration
|
55 |
|
56 |
+
def transcribeFile(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad):
|
57 |
try:
|
58 |
source, sourceName = self.getSource(urlData, uploadFile, microphoneData)
|
59 |
|
|
|
68 |
model_cache[selectedModel] = model
|
69 |
|
70 |
# The results
|
71 |
+
if (vad == 'silero-vad'):
|
72 |
+
# Use Silero VAD
|
73 |
+
if (self.vad_model is None):
|
74 |
+
self.vad_model = VadTranscription()
|
75 |
+
result = self.vad_model.transcribe(source, lambda audio : model.transcribe(audio, language=selectedLanguage, task=task))
|
76 |
+
else:
|
77 |
+
# Default VAD
|
78 |
+
result = model.transcribe(source, language=selectedLanguage, task=task)
|
79 |
|
80 |
text = result["text"]
|
81 |
|
|
|
163 |
ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
|
164 |
ui_description += " as well as speech translation and language identification. "
|
165 |
|
166 |
+
ui_description += "\n\n" + "Note: You can upload more audio (and even video) types by changing to All Files (*.*) in the file selector. For longer audio files (>10 minutes), "
|
167 |
+
ui_description += "it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
|
168 |
|
169 |
if inputAudioMaxDuration > 0:
|
170 |
ui_description += "\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s"
|
|
|
176 |
gr.Audio(source="upload", type="filepath", label="Upload Audio"),
|
177 |
gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
|
178 |
gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
|
179 |
+
gr.Dropdown(choices=["none", "silero-vad"], label="VAD"),
|
180 |
], outputs=[
|
181 |
gr.File(label="Download"),
|
182 |
gr.Text(label="Transcription"),
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import Counter
|
2 |
+
from dis import dis
|
3 |
+
from typing import Any, Iterator, List, Dict
|
4 |
+
|
5 |
+
from pprint import pprint
|
6 |
+
import torch
|
7 |
+
|
8 |
+
import ffmpeg
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
SPEECH_TRESHOLD = 0.3
|
12 |
+
MAX_SILENT_PERIOD = 10 # seconds
|
13 |
+
|
14 |
+
SEGMENT_PADDING_LEFT = 1 # Start detected text segment early
|
15 |
+
SEGMENT_PADDING_RIGHT = 4 # End detected segments late
|
16 |
+
|
17 |
+
def load_audio(file: str, sample_rate: int = 16000,
|
18 |
+
start_time: str = None, duration: str = None):
|
19 |
+
"""
|
20 |
+
Open an audio file and read as mono waveform, resampling as necessary
|
21 |
+
|
22 |
+
Parameters
|
23 |
+
----------
|
24 |
+
file: str
|
25 |
+
The audio file to open
|
26 |
+
|
27 |
+
sr: int
|
28 |
+
The sample rate to resample the audio if necessary
|
29 |
+
|
30 |
+
start_time: str
|
31 |
+
The start time, using the standard FFMPEG time duration syntax, or None to disable.
|
32 |
+
|
33 |
+
duration: str
|
34 |
+
The duration, using the standard FFMPEG time duration syntax, or None to disable.
|
35 |
+
|
36 |
+
Returns
|
37 |
+
-------
|
38 |
+
A NumPy array containing the audio waveform, in float32 dtype.
|
39 |
+
"""
|
40 |
+
try:
|
41 |
+
inputArgs = {'threads': 0}
|
42 |
+
|
43 |
+
if (start_time is not None):
|
44 |
+
inputArgs['ss'] = start_time
|
45 |
+
if (duration is not None):
|
46 |
+
inputArgs['t'] = duration
|
47 |
+
|
48 |
+
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
49 |
+
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
|
50 |
+
out, _ = (
|
51 |
+
ffmpeg.input(file, **inputArgs)
|
52 |
+
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
|
53 |
+
.run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
|
54 |
+
)
|
55 |
+
except ffmpeg.Error as e:
|
56 |
+
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}")
|
57 |
+
|
58 |
+
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
59 |
+
|
60 |
+
class VadTranscription:
|
61 |
+
def __init__(self):
|
62 |
+
self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
|
63 |
+
|
64 |
+
(self.get_speech_timestamps, _, _, _, _) = utils
|
65 |
+
|
66 |
+
def transcribe(self, audio: str, whisperCallable):
|
67 |
+
SAMPLING_RATE = 16000
|
68 |
+
wav = load_audio(audio, sample_rate=SAMPLING_RATE)
|
69 |
+
|
70 |
+
# get speech timestamps from full audio file
|
71 |
+
sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=SAMPLING_RATE, threshold=SPEECH_TRESHOLD)
|
72 |
+
seconds_timestamps = self.convert_seconds(sample_timestamps, sampling_rate=SAMPLING_RATE)
|
73 |
+
|
74 |
+
padded = self.pad_timestamps(seconds_timestamps, SEGMENT_PADDING_LEFT, SEGMENT_PADDING_RIGHT)
|
75 |
+
merged = self.merge_timestamps(padded, MAX_SILENT_PERIOD)
|
76 |
+
|
77 |
+
print("Timestamps:")
|
78 |
+
pprint(merged)
|
79 |
+
|
80 |
+
result = {
|
81 |
+
'text': "",
|
82 |
+
'segments': [],
|
83 |
+
'language': ""
|
84 |
+
}
|
85 |
+
languageCounter = Counter()
|
86 |
+
|
87 |
+
# For each time segment, run whisper
|
88 |
+
for segment in merged:
|
89 |
+
segment_start = segment['start']
|
90 |
+
segment_duration = segment['end'] - segment_start
|
91 |
+
|
92 |
+
segment_audio = load_audio(audio, sample_rate=SAMPLING_RATE, start_time = str(segment_start) + "s", duration = str(segment_duration) + "s")
|
93 |
+
|
94 |
+
print("Running whisper on " + str(segment_start) + ", duration: " + str(segment_duration))
|
95 |
+
segment_result = whisperCallable(segment_audio)
|
96 |
+
adjusted_segments = self.adjust_whisper_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
|
97 |
+
|
98 |
+
# Append to output
|
99 |
+
result['text'] += segment_result['text']
|
100 |
+
result['segments'].extend(adjusted_segments)
|
101 |
+
|
102 |
+
# Increment detected language
|
103 |
+
languageCounter[segment_result['language']] += 1
|
104 |
+
|
105 |
+
if len(languageCounter) > 0:
|
106 |
+
result['language'] = languageCounter.most_common(1)[0][0]
|
107 |
+
|
108 |
+
return result
|
109 |
+
|
110 |
+
def adjust_whisper_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
|
111 |
+
result = []
|
112 |
+
|
113 |
+
for segment in segments:
|
114 |
+
segment_start = float(segment['start'])
|
115 |
+
segment_end = float(segment['end'])
|
116 |
+
|
117 |
+
# Filter segments?
|
118 |
+
if (max_source_time is not None):
|
119 |
+
if (segment_start > max_source_time):
|
120 |
+
continue
|
121 |
+
segment_end = min(max_source_time, segment_end)
|
122 |
+
|
123 |
+
new_segment = segment.copy()
|
124 |
+
|
125 |
+
# Add to start and end
|
126 |
+
new_segment['start'] = segment_start + adjust_seconds
|
127 |
+
new_segment['end'] = segment_end + adjust_seconds
|
128 |
+
result.append(new_segment)
|
129 |
+
return result
|
130 |
+
|
131 |
+
def pad_timestamps(self, timestamps: List[Dict[str, Any]], padding_left: float, padding_right: float):
|
132 |
+
result = []
|
133 |
+
|
134 |
+
for entry in timestamps:
|
135 |
+
segment_start = entry['start']
|
136 |
+
segment_end = entry['end']
|
137 |
+
|
138 |
+
if padding_left is not None:
|
139 |
+
segment_start = max(0, segment_start - padding_left)
|
140 |
+
if padding_right is not None:
|
141 |
+
segment_end = segment_end + padding_right
|
142 |
+
|
143 |
+
result.append({ 'start': segment_start, 'end': segment_end })
|
144 |
+
|
145 |
+
return result
|
146 |
+
|
147 |
+
def merge_timestamps(self, timestamps: List[Dict[str, Any]], max_distance: float):
|
148 |
+
result = []
|
149 |
+
current_entry = None
|
150 |
+
|
151 |
+
for entry in timestamps:
|
152 |
+
if current_entry is None:
|
153 |
+
current_entry = entry
|
154 |
+
continue
|
155 |
+
|
156 |
+
# Get distance to the previous entry
|
157 |
+
distance = entry['start'] - current_entry['end']
|
158 |
+
|
159 |
+
if distance <= max_distance:
|
160 |
+
# Merge
|
161 |
+
current_entry['end'] = entry['end']
|
162 |
+
else:
|
163 |
+
# Output current entry
|
164 |
+
result.append(current_entry)
|
165 |
+
current_entry = entry
|
166 |
+
|
167 |
+
# Add final entry
|
168 |
+
if current_entry is not None:
|
169 |
+
result.append(current_entry)
|
170 |
+
|
171 |
+
return result
|
172 |
+
|
173 |
+
def convert_seconds(self, timestamps: List[Dict[str, Any]], sampling_rate: int):
|
174 |
+
result = []
|
175 |
+
|
176 |
+
for entry in timestamps:
|
177 |
+
start = entry['start']
|
178 |
+
end = entry['end']
|
179 |
+
|
180 |
+
result.append({
|
181 |
+
'start': start / sampling_rate,
|
182 |
+
'end': end / sampling_rate
|
183 |
+
})
|
184 |
+
return result
|
185 |
+
|