Rakib committed
Commit 5b106c5
1 Parent(s): 01f4d04

Initial commit

Files changed (4)
  1. gradio-app.py +193 -0
  2. models.py +132 -0
  3. utils.py +23 -0
  4. vad.py +273 -0
gradio-app.py ADDED
@@ -0,0 +1,193 @@
+ import subprocess
+ import time
+
+ import gradio as gr
+ import librosa
+ import pytube as pt
+ from models import asr, processor
+ from utils import format_timestamp
+ from vad import SpeechTimestampsMap, collect_chunks, get_speech_timestamps
+
+ ## Details: https://huggingface.co/docs/diffusers/optimization/fp16#automatic-mixed-precision-amp
+ # from torch import autocast
+
+ apply_vad = True
+ vad_parameters = {}
+
+ # task = "transcribe"  # transcribe or translate
+ # language = "bn"
+ # asr.model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
+ # asr.model.config.max_new_tokens = 448  # default is 448
+
+
+ def _preprocess(filename):
+     """Convert the input file to 16 kHz mono 16-bit PCM WAV using ffmpeg."""
+     audio_name = "audio.wav"
+     subprocess.call(
+         [
+             "ffmpeg",
+             "-y",
+             "-i",
+             filename,
+             "-acodec",
+             "pcm_s16le",
+             "-ar",
+             "16000",
+             "-ac",
+             "1",
+             "-loglevel",
+             "quiet",
+             audio_name,
+         ]
+     )
+     return audio_name
+
+
+ def transcribe(microphone, file_upload):
+     warn_output = ""
+     if (microphone is not None) and (file_upload is not None):
+         warn_output = (
+             "WARNING: You've uploaded an audio file and used the microphone. "
+             "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
+         )
+
+     elif (microphone is None) and (file_upload is None):
+         return "ERROR: You have to either use the microphone or upload an audio file"
+
+     file = microphone if microphone is not None else file_upload
+     print(f"\n\nFile is: {file}\n\n")
+
+     # _preprocess() is not needed when a filename string is passed to the asr pipeline,
+     # since the pipeline decodes it with ffmpeg automatically. It is only required when
+     # passing an ndarray, e.g. one loaded with librosa.load().
+     start_time = time.time()
+     print("Starting Preprocessing")
+     # speech_array = _preprocess(filename=file)
+     filename = _preprocess(filename=file)
+     speech_array, sample_rate = librosa.load(filename, sr=16_000)
+     if apply_vad:
+         duration = speech_array.shape[0] / sample_rate
+         print(f"Processing audio with duration: {format_timestamp(duration)}")
+         speech_chunks = get_speech_timestamps(speech_array, **vad_parameters)
+         speech_array = collect_chunks(speech_array, speech_chunks)
+         print(f"VAD filter removed {format_timestamp(duration - (speech_array.shape[0] / sample_rate))} of audio")
+         remaining_segments = ", ".join(
+             f'[{format_timestamp(chunk["start"] / sample_rate)} -> {format_timestamp(chunk["end"] / sample_rate)}]'
+             for chunk in speech_chunks
+         )
+         print(f"VAD filter kept the following audio segments: {remaining_segments}")
+         if not remaining_segments:
+             return "ERROR: No speech detected in the audio file"
+
+     print(f"\n Preprocessing COMPLETED in {round(time.time()-start_time, 2)}s \n")
+
+     start_time = time.time()
+     print("Starting Inference")
+     text = asr(speech_array)["text"]
+     # text = asr(file)["text"]
+     # with autocast("cuda"):
+     #     text = asr(speech_array)["text"]
+     print(f"\n Inference COMPLETED in {round(time.time()-start_time, 2)}s \n")
+
+     return warn_output + text
+
+
+ def _return_yt_html_embed(yt_url):
+     if "?v=" in yt_url:
+         video_id = yt_url.split("?v=")[-1].split("&")[0]
+     else:
+         video_id = yt_url.split("/")[-1].split("?feature=")[0]
+
+     print(f"\n\nYT ID is: {video_id}\n\n")
+     return f'<center><iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe> </center>'
+
+
+ def yt_transcribe(yt_url):
+     start_time = time.time()
+     yt = pt.YouTube(yt_url)
+     html_embed_str = _return_yt_html_embed(yt_url)
+     stream = yt.streams.filter(only_audio=True)[0]
+     filename = "audio.mp3"
+     stream.download(filename=filename)
+     print(f"\n YT Audio Downloaded in {round(time.time()-start_time, 2)}s \n")
+
+     # Preprocessing is not needed here: the filename string is passed directly to the asr
+     # pipeline, which decodes it with ffmpeg. It would only be required when passing an
+     # ndarray loaded with librosa.load().
+     start_time = time.time()
+     # print("Starting Preprocessing")
+     # speech_array = _preprocess(filename=filename)
+     # filename = _preprocess(filename=filename)
+     # speech_array, sample_rate = librosa.load(filename, sr=16_000)
+     # print(f"\n Preprocessing COMPLETED in {round(time.time()-start_time, 2)}s \n")
+
+     start_time = time.time()
+     print("Starting Inference")
+     text = asr(filename)["text"]
+     # with autocast("cuda"):
+     #     text = asr(speech_array)["text"]
+     print(f"\n Inference COMPLETED in {round(time.time()-start_time, 2)}s \n")
+
+     return html_embed_str, text
+
+
+ mf_transcribe = gr.Interface(
+     fn=transcribe,
+     inputs=[
+         gr.Audio(source="microphone", type="filepath", label="Microphone"),
+         gr.Audio(source="upload", type="filepath", label="Upload File"),
+     ],
+     outputs="text",
+     title="Bangla Demo: Transcribe Audio",
+     description=(
+         "Transcribe long-form microphone or audio inputs in BANGLA with the click of a button!"
+     ),
+     allow_flagging="never",
+ )
+
+ yt_transcribe = gr.Interface(
+     fn=yt_transcribe,
+     inputs=[
+         gr.Textbox(
+             lines=1,
+             placeholder="Paste the URL to a Bangla language YouTube video here",
+             label="YouTube URL",
+         )
+     ],
+     outputs=["html", "text"],
+     title="Bangla Demo: Transcribe YouTube",
+     description=(
+         "Transcribe long-form YouTube videos in BANGLA with the click of a button!"
+     ),
+     allow_flagging="never",
+ )
+
+ # def transcribe2(audio, state=""):
+ #     text = "text"
+ #     state += text + " "
+ #     return state, state
+
+ # Set the starting state to an empty string
+
+ # real_transcribe = gr.Interface(
+ #     fn=transcribe2,
+ #     inputs=[
+ #         gr.Audio(source="microphone", type="filepath", streaming=True),
+ #         "state",
+ #     ],
+ #     outputs=[
+ #         "textbox",
+ #         "state",
+ #     ],
+ #     live=True)
+
+
+ # demo = gr.TabbedInterface([mf_transcribe, yt_transcribe, real_transcribe], ["Transcribe Bangla Audio", "Transcribe Bangla YouTube Video", "real time"])
+ demo = gr.TabbedInterface(
+     [mf_transcribe, yt_transcribe],
+     ["Transcribe Bangla Audio", "Transcribe Bangla YouTube Video"],
+ )
+
+
+ if __name__ == "__main__":
+     demo.queue()
+     demo.launch(share=True)
+     # demo.launch(share=True, server_name="0.0.0.0", server_port=8080)
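Note: vad_parameters is left empty above, so get_speech_timestamps() runs with its defaults. A minimal sketch of how it could be populated; the keys mirror the keyword arguments defined in vad.py below, and the values shown are illustrative, not part of this commit:

# Hypothetical VAD tuning; keys correspond to get_speech_timestamps() in vad.py.
vad_parameters = {
    "threshold": 0.5,                 # speech-probability cutoff
    "min_speech_duration_ms": 800,    # drop speech chunks shorter than this
    "min_silence_duration_ms": 1000,  # silence required to close a chunk
    "speech_pad_ms": 200,             # padding added around each chunk
}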
models.py ADDED
@@ -0,0 +1,132 @@
+ import os
+
+ abs_path = os.path.abspath('.')
+ base_dir = os.path.dirname(os.path.dirname(abs_path))
+ os.environ['TRANSFORMERS_CACHE'] = os.path.join(base_dir, 'models_cache')
+
+ import torch
+ # Details: https://huggingface.co/docs/diffusers/optimization/fp16#enable-cudnn-autotuner
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ from transformers import pipeline, AutoTokenizer, AutoFeatureExtractor, AutoConfig, WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer, WhisperFeatureExtractor
+ from typing import Union, BinaryIO
+ # from optimum.bettertransformer import BetterTransformer
+
+ language = '<|bn|>'
+ # language = '<|en|>'
+ task = "transcribe"  # transcribe or translate
+
+ # model_name = 'openai/whisper-tiny.en'
+ # model_name = 'openai/whisper-base.en'
+ # model_name = 'openai/whisper-small.en'
+ # model_name = 'openai/whisper-medium'
+ ## v2: trained for more epochs with regularization
+ # model_name = 'openai/whisper-large-v2'
+
+ ## Bangla
+ # model_name = 'Rakib/whisper-tiny-bn'
+ # model_name = 'anuragshas/whisper-small-bn'
+ # model_name = 'anuragshas/whisper-large-v2-bn'
+ # model_name = "Rakib/whisper-small-bn"
+ # model_name = "Rakib/whisper-small-bn-all"
+ # model_name = "Rakib/whisper-small-bn-all-600"
+ # model_name = "Rakib/whisper-small-bn-all-600-v2"
+ model_name = "Rakib/whisper-small-bn-crblp"
+
+ ## Lets you know the device count: cuda:0 or cuda:1
+ # print(torch.cuda.device_count())
+
+ device = 0 if torch.cuda.is_available() else -1
+ # device = -1  # exclusively CPU
+
+ print(f"Using device: {'GPU' if device == 0 else 'CPU'}")
+
+ if device != 0:
+     print("[Warning!] Using CPU could hamper performance")
+
+ print("Loading Tokenizer for ASR Speech-to-Text Model...\n" + "*" * 100)
+ # tokenizer = AutoTokenizer.from_pretrained(model_name, language=language, task=task)
+ # tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+ tokenizer = WhisperTokenizer.from_pretrained(model_name)
+ # tokenizer(['�', '�্র'], add_prefix_space=True, add_special_tokens=False).input_ids
+
+ print("Loading Feature Extractor for ASR Speech-to-Text Model...\n" + "*" * 100)
+ # feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
+
+ print("Loading Config for ASR Speech-to-Text Model...\n" + "*" * 100)
+ config = AutoConfig.from_pretrained(model_name)
+
+ print("Loading Processor for ASR Speech-to-Text Model...\n" + "*" * 100)
+ processor = WhisperProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ print("Loading WHISPER ASR Speech-to-Text Model...\n" + "*" * 100)
+ model = WhisperForConditionalGeneration.from_pretrained(model_name)
+
+ ## BetterTransformer (not needed if PyTorch 2.0 works)
+ ## (currently ~2 s faster inference than PyTorch 2.0)
+ # model = WhisperForConditionalGeneration.from_pretrained(model_name)
+ # model = BetterTransformer.transform(model)
+
+ ## bitsandbytes (Linux & GPU only; requires a conda env with conda-based PyTorch)
+ ## Currently only reduces model size; inference is slower than the native model.
+ ## from_pretrained docs: https://huggingface.co/docs/transformers/v4.25.1/en/main_classes/model#transformers.PreTrainedModel.from_pretrained
+ # model = WhisperForConditionalGeneration.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
+
+ ## For PyTorch 2.0 (Linux only)
+ # model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device="cuda:0")
+ ## mode options are "default", "reduce-overhead" and "max-autotune". See: https://pytorch.org/get-started/pytorch-2.0/#modes
+ # model = torch.compile(model, mode="default")
+
+
+ asr = pipeline(
+     task="automatic-speech-recognition",
+     model=model,
+     tokenizer=tokenizer,
+     feature_extractor=feature_extractor,
+     # processor=processor,  # no effect, see: https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/automatic_speech_recognition.py
+     # config=config,  # no effect, see: https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/automatic_speech_recognition.py
+     device=device,  # 0 (first GPU) for GPU, -1 for CPU
+     ## Chunk files longer than 30 s into shorter samples.
+     chunk_length_s=30,
+     ## The amount of overlap (in seconds) to be discarded while stitching the inferred chunks.
+     ## stride_length_s is a tuple of the left and right stride (overlap) lengths.
+     ## With only one number, both sides get the same stride by default.
+     ## If stride_length_s is not provided, each side defaults to chunk_length_s / 6.
+     # stride_length_s=[8, 8],
+     stride_length_s=[5, 5],
+     # stride_length_s=[6, 0],
+     batch_size=16,
+     ignore_warning=True,
+     ## Force Whisper to generate timestamps so that the chunking and stitching can be accurate.
+     # return_timestamps=True,
+     generate_kwargs={
+         'language': language,
+         'task': task,
+         'repetition_penalty': 1.8,
+         'num_beams': 2,
+         'max_new_tokens': 448,
+         'early_stopping': True,
+         # 'renormalize_logits': True,
+         # [16867]: �, [16867, 156, 100, 235, 156, 12811]: �্র
+         'bad_words_ids': [[16867], [16867, 156, 100, 235, 156, 12811]],
+         # 'suppress_tokens': [16867, 156, 100, 235, 156, 12811],
+     },
+ )
+
+
+ def transcribe(speech_array: Union[str, BinaryIO], language: str = "en") -> str:
+     """
+     Transcribes audio to text.
+     Args:
+         speech_array: audio as a file path, file-like object, or numpy array
+         language (str): language code, e.g. "bn" or "en"
+     Returns:
+         A string containing the transcription.
+     """
+     asr.model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
+     # asr.model.config.max_new_tokens = 448  # default is 448
+
+     result = asr(speech_array)
+
+     return str(result["text"])
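For reference, a minimal usage sketch of the objects this module exports (gradio-app.py imports asr and processor directly); the audio path is a placeholder, not part of the commit:

# Illustrative usage only; "sample.wav" is a hypothetical file path.
from models import asr, transcribe

# The pipeline accepts a file path and decodes it with ffmpeg internally;
# language and task are already forced through generate_kwargs above.
text = asr("sample.wav")["text"]

# Or use the helper, which additionally sets forced_decoder_ids on the model config.
text_bn = transcribe("sample.wav", language="bn")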
utils.py ADDED
@@ -0,0 +1,23 @@
+ def format_timestamp(
+     seconds: float,
+     always_include_hours: bool = False,
+     decimal_marker: str = ".",
+ ) -> str:
+     assert seconds >= 0, "non-negative timestamp expected"
+     milliseconds = round(seconds * 1000.0)
+
+     hours = milliseconds // 3_600_000
+     milliseconds -= hours * 3_600_000
+
+     minutes = milliseconds // 60_000
+     milliseconds -= minutes * 60_000
+
+     seconds = milliseconds // 1_000
+     milliseconds -= seconds * 1_000
+
+     hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+     return (
+         f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+     )
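A quick illustration of the formatting behaviour (example values chosen for this note, not part of the commit):

# Illustrative calls and their return values.
format_timestamp(75.5)                            # "01:15.500"
format_timestamp(3661.25)                         # "01:01:01.250"
format_timestamp(5.0, always_include_hours=True)  # "00:00:05.000"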
vad.py ADDED
@@ -0,0 +1,273 @@
+ import bisect
+ import functools
+ import os
+ import warnings
+
+ from typing import List, Optional
+
+ import numpy as np
+
+ # The code below is adapted from https://github.com/snakers4/silero-vad.
+
+
+ def get_assets_path():
+     """Returns the path to the assets directory."""
+     return os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
+
+
+ def get_speech_timestamps(
+     audio: np.ndarray,
+     *,
+     threshold: float = 0.5,
+     # min_speech_duration_ms: int = 250,
+     min_speech_duration_ms: int = 800,
+     max_speech_duration_s: float = float("inf"),
+     # min_silence_duration_ms: int = 2000,
+     min_silence_duration_ms: int = 1000,
+     window_size_samples: int = 1024,
+     speech_pad_ms: int = 200,
+ ) -> List[dict]:
+     """Splits long audio into speech chunks using the Silero VAD.
+     Args:
+         audio: One-dimensional float array.
+         threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk;
+             probabilities ABOVE this value are considered SPEECH. It is better to tune this
+             parameter for each dataset separately, but a "lazy" 0.5 is pretty good for most datasets.
+         min_speech_duration_ms: Final speech chunks shorter than min_speech_duration_ms are thrown out.
+         max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
+             than max_speech_duration_s will be split at the timestamp of the last silence that
+             lasts more than 100 ms (if any), to prevent aggressive cutting. Otherwise, they will be
+             split aggressively just before max_speech_duration_s.
+         min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms
+             before separating it.
+         window_size_samples: Audio chunks of window_size_samples size are fed to the Silero VAD model.
+             WARNING! Silero VAD models were trained using 512, 1024, and 1536 samples for a 16000 sample rate.
+             Other values may affect model performance!
+         speech_pad_ms: Final speech chunks are padded by speech_pad_ms on each side.
+     Returns:
+         List of dicts containing the begin and end samples of each speech chunk.
+     """
+     if window_size_samples not in [512, 1024, 1536]:
+         warnings.warn(
+             "Unusual window_size_samples! Supported window_size_samples:\n"
+             " - [512, 1024, 1536] for 16000 sampling_rate"
+         )
+
+     sampling_rate = 16000
+     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
+     speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+     max_speech_samples = (
+         sampling_rate * max_speech_duration_s
+         - window_size_samples
+         - 2 * speech_pad_samples
+     )
+     min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+     min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
+
+     audio_length_samples = len(audio)
+
+     model = get_vad_model()
+     state = model.get_initial_state(batch_size=1)
+
+     speech_probs = []
+     for current_start_sample in range(0, audio_length_samples, window_size_samples):
+         chunk = audio[current_start_sample : current_start_sample + window_size_samples]
+         if len(chunk) < window_size_samples:
+             chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
+         speech_prob, state = model(chunk, state, sampling_rate)
+         speech_probs.append(speech_prob)
+
+     triggered = False
+     speeches = []
+     current_speech = {}
+     neg_threshold = threshold - 0.15
+
+     # to save a potential segment end (and tolerate some silence)
+     temp_end = 0
+     # to save potential segment limits in case the maximum segment size is reached
+     prev_end = next_start = 0
+
+     for i, speech_prob in enumerate(speech_probs):
+         if (speech_prob >= threshold) and temp_end:
+             temp_end = 0
+             if next_start < prev_end:
+                 next_start = window_size_samples * i
+
+         if (speech_prob >= threshold) and not triggered:
+             triggered = True
+             current_speech["start"] = window_size_samples * i
+             continue
+
+         if (
+             triggered
+             and (window_size_samples * i) - current_speech["start"] > max_speech_samples
+         ):
+             if prev_end:
+                 current_speech["end"] = prev_end
+                 speeches.append(current_speech)
+                 current_speech = {}
+                 # previously reached silence (< neg_threshold) and is still not speech (< threshold)
+                 if next_start < prev_end:
+                     triggered = False
+                 else:
+                     current_speech["start"] = next_start
+                 prev_end = next_start = temp_end = 0
+             else:
+                 current_speech["end"] = window_size_samples * i
+                 speeches.append(current_speech)
+                 current_speech = {}
+                 prev_end = next_start = temp_end = 0
+                 triggered = False
+                 continue
+
+         if (speech_prob < neg_threshold) and triggered:
+             if not temp_end:
+                 temp_end = window_size_samples * i
+             # condition to avoid cutting in very short silence
+             if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
+                 prev_end = temp_end
+             if (window_size_samples * i) - temp_end < min_silence_samples:
+                 continue
+             else:
+                 current_speech["end"] = temp_end
+                 if (
+                     current_speech["end"] - current_speech["start"]
+                 ) > min_speech_samples:
+                     speeches.append(current_speech)
+                 current_speech = {}
+                 prev_end = next_start = temp_end = 0
+                 triggered = False
+                 continue
+
+     if (
+         current_speech
+         and (audio_length_samples - current_speech["start"]) > min_speech_samples
+     ):
+         current_speech["end"] = audio_length_samples
+         speeches.append(current_speech)
+
+     for i, speech in enumerate(speeches):
+         if i == 0:
+             speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
+         if i != len(speeches) - 1:
+             silence_duration = speeches[i + 1]["start"] - speech["end"]
+             if silence_duration < 2 * speech_pad_samples:
+                 speech["end"] += int(silence_duration // 2)
+                 speeches[i + 1]["start"] = int(
+                     max(0, speeches[i + 1]["start"] - silence_duration // 2)
+                 )
+             else:
+                 speech["end"] = int(
+                     min(audio_length_samples, speech["end"] + speech_pad_samples)
+                 )
+                 speeches[i + 1]["start"] = int(
+                     max(0, speeches[i + 1]["start"] - speech_pad_samples)
+                 )
+         else:
+             speech["end"] = int(
+                 min(audio_length_samples, speech["end"] + speech_pad_samples)
+             )
+
+     return speeches
+
+
+ def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
+     """Collects and concatenates audio chunks."""
+     if not chunks:
+         return np.array([], dtype=np.float32)
+
+     return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
+
+
+ class SpeechTimestampsMap:
+     """Helper class to restore original speech timestamps."""
+
+     def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
+         self.sampling_rate = sampling_rate
+         self.time_precision = time_precision
+         self.chunk_end_sample = []
+         self.total_silence_before = []
+
+         previous_end = 0
+         silent_samples = 0
+
+         for chunk in chunks:
+             silent_samples += chunk["start"] - previous_end
+             previous_end = chunk["end"]
+
+             self.chunk_end_sample.append(chunk["end"] - silent_samples)
+             self.total_silence_before.append(silent_samples / sampling_rate)
+
+     def get_original_time(
+         self,
+         time: float,
+         chunk_index: Optional[int] = None,
+     ) -> float:
+         if chunk_index is None:
+             chunk_index = self.get_chunk_index(time)
+
+         total_silence_before = self.total_silence_before[chunk_index]
+         return round(total_silence_before + time, self.time_precision)
+
+     def get_chunk_index(self, time: float) -> int:
+         sample = int(time * self.sampling_rate)
+         return min(
+             bisect.bisect(self.chunk_end_sample, sample),
+             len(self.chunk_end_sample) - 1,
+         )
+
+
+ @functools.lru_cache
+ def get_vad_model():
+     """Returns the VAD model instance."""
+     path = os.path.join(get_assets_path(), "silero_vad.onnx")
+     return SileroVADModel(path)
+
+
+ class SileroVADModel:
+     def __init__(self, path):
+         try:
+             import onnxruntime
+         except ImportError as e:
+             raise RuntimeError(
+                 "Applying the VAD filter requires the onnxruntime package"
+             ) from e
+
+         opts = onnxruntime.SessionOptions()
+         opts.inter_op_num_threads = 1
+         opts.intra_op_num_threads = 1
+         opts.log_severity_level = 4
+
+         self.session = onnxruntime.InferenceSession(
+             path,
+             providers=["CPUExecutionProvider"],
+             sess_options=opts,
+         )
+
+     def get_initial_state(self, batch_size: int):
+         h = np.zeros((2, batch_size, 64), dtype=np.float32)
+         c = np.zeros((2, batch_size, 64), dtype=np.float32)
+         return h, c
+
+     def __call__(self, x, state, sr: int):
+         if len(x.shape) == 1:
+             x = np.expand_dims(x, 0)
+         if len(x.shape) > 2:
+             raise ValueError(
+                 f"Too many dimensions for input audio chunk {len(x.shape)}"
+             )
+         if sr / x.shape[1] > 31.25:
+             raise ValueError("Input audio chunk is too short")
+
+         h, c = state
+
+         ort_inputs = {
+             "input": x,
+             "h": h,
+             "c": c,
+             "sr": np.array(sr, dtype="int64"),
+         }
+
+         out, h, c = self.session.run(None, ort_inputs)
+         state = (h, c)
+
+         return out, state
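SpeechTimestampsMap is imported in gradio-app.py but not yet wired in. A minimal, self-contained sketch of how it maps a time measured on the VAD-filtered audio back to the original timeline; the chunk boundaries below are illustrative values of the kind get_speech_timestamps() returns, not output from this commit:

# Illustrative only; chunk boundaries are hypothetical sample indices.
from vad import SpeechTimestampsMap

sampling_rate = 16000
chunks = [
    {"start": 0, "end": 2 * sampling_rate},                  # speech from 0.0 s to 2.0 s
    {"start": 5 * sampling_rate, "end": 8 * sampling_rate},  # speech from 5.0 s to 8.0 s
]

ts_map = SpeechTimestampsMap(chunks, sampling_rate)

# 1.0 s into the filtered audio lies in the first chunk, so it maps to 1.0 s.
print(ts_map.get_original_time(1.0))  # 1.0
# 3.0 s into the filtered audio lies in the second chunk: 3.0 s + 3.0 s of removed silence = 6.0 s.
print(ts_map.get_original_time(3.0))  # 6.0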