Refactor STT module for lazy loading and better code organization
Browse files — kitt/core/stt.py (+16 −4)
kitt/core/stt.py
CHANGED
@@ -8,10 +8,18 @@ import torchaudio
|
|
8 |
from loguru import logger
|
9 |
from transformers import pipeline
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
def save_audio_as_wav(data, sample_rate, file_path):
|
@@ -23,6 +31,10 @@ def save_audio_as_wav(data, sample_rate, file_path):
|
|
23 |
|
24 |
|
25 |
def transcribe_audio(audio):
|
|
|
|
|
|
|
|
|
26 |
sample_rate, data = audio
|
27 |
try:
|
28 |
data = data.astype(np.float32)
|
|
|
8 |
from loguru import logger
|
9 |
from transformers import pipeline
|
10 |
|
11 |
+
|
12 |
+
|
13 |
+
transcriber = None
|
14 |
+
|
15 |
+
|
16 |
+
def load_stt():
|
17 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
+
transcriber = pipeline(
|
19 |
+
"automatic-speech-recognition", model="openai/whisper-base.en", device=device
|
20 |
+
)
|
21 |
+
return transcriber
|
22 |
+
|
23 |
|
24 |
|
25 |
def save_audio_as_wav(data, sample_rate, file_path):
|
|
|
31 |
|
32 |
|
33 |
def transcribe_audio(audio):
|
34 |
+
global transcriber
|
35 |
+
if transcriber is None:
|
36 |
+
transcriber = load_stt()
|
37 |
+
|
38 |
sample_rate, data = audio
|
39 |
try:
|
40 |
data = data.astype(np.float32)
|