Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
IliaLarchenko
commited on
Commit
•
9c01649
1
Parent(s):
5efe29a
Added support of local STT models using HuggingFace
Browse files- api/audio.py +11 -2
api/audio.py
CHANGED
@@ -10,6 +10,8 @@ from utils.errors import APIError, AudioConversionError
|
|
10 |
from typing import List, Dict, Optional, Generator, Tuple
|
11 |
import webrtcvad
|
12 |
|
|
|
|
|
13 |
|
14 |
def detect_voice(audio: np.ndarray, sample_rate: int = 48000, frame_duration: int = 30) -> bool:
|
15 |
vad = webrtcvad.Vad()
|
@@ -43,6 +45,9 @@ class STTManager:
|
|
43 |
self.status = self.test_stt()
|
44 |
self.streaming = self.status
|
45 |
|
|
|
|
|
|
|
46 |
def numpy_audio_to_bytes(self, audio_data: np.ndarray) -> bytes:
|
47 |
"""
|
48 |
Convert a numpy array of audio data to bytes.
|
@@ -79,7 +84,7 @@ class STTManager:
|
|
79 |
if has_voice:
|
80 |
audio_buffer = np.concatenate((audio_buffer, audio[1]))
|
81 |
|
82 |
-
is_short = len(audio_buffer) /
|
83 |
|
84 |
if is_short or (has_voice and not ended):
|
85 |
return audio_buffer, np.array([], dtype=np.int16)
|
@@ -101,15 +106,16 @@ class STTManager:
|
|
101 |
:param context: Optional context for the transcription.
|
102 |
:return: Transcribed text.
|
103 |
"""
|
104 |
-
audio_bytes = self.numpy_audio_to_bytes(audio)
|
105 |
try:
|
106 |
if self.config.stt.type == "OPENAI_API":
|
|
|
107 |
data = ("temp.wav", audio_bytes, "audio/wav")
|
108 |
client = OpenAI(base_url=self.config.stt.url, api_key=self.config.stt.key)
|
109 |
transcription = client.audio.transcriptions.create(
|
110 |
model=self.config.stt.name, file=data, response_format="text", prompt=context
|
111 |
)
|
112 |
elif self.config.stt.type == "HF_API":
|
|
|
113 |
headers = {"Authorization": "Bearer " + self.config.stt.key}
|
114 |
response = requests.post(self.config.stt.url, headers=headers, data=audio_bytes)
|
115 |
if response.status_code != 200:
|
@@ -118,6 +124,9 @@ class STTManager:
|
|
118 |
transcription = response.json().get("text", None)
|
119 |
if transcription is None:
|
120 |
raise APIError("STT Error: No transcription returned by HF API")
|
|
|
|
|
|
|
121 |
except APIError:
|
122 |
raise
|
123 |
except Exception as e:
|
|
|
10 |
from typing import List, Dict, Optional, Generator, Tuple
|
11 |
import webrtcvad
|
12 |
|
13 |
+
from transformers import pipeline
|
14 |
+
|
15 |
|
16 |
def detect_voice(audio: np.ndarray, sample_rate: int = 48000, frame_duration: int = 30) -> bool:
|
17 |
vad = webrtcvad.Vad()
|
|
|
45 |
self.status = self.test_stt()
|
46 |
self.streaming = self.status
|
47 |
|
48 |
+
if config.stt.type == "HF_LOCAL":
|
49 |
+
self.pipe = pipeline("automatic-speech-recognition", model=config.stt.name)
|
50 |
+
|
51 |
def numpy_audio_to_bytes(self, audio_data: np.ndarray) -> bytes:
|
52 |
"""
|
53 |
Convert a numpy array of audio data to bytes.
|
|
|
84 |
if has_voice:
|
85 |
audio_buffer = np.concatenate((audio_buffer, audio[1]))
|
86 |
|
87 |
+
is_short = len(audio_buffer) / self.SAMPLE_RATE < 1.0
|
88 |
|
89 |
if is_short or (has_voice and not ended):
|
90 |
return audio_buffer, np.array([], dtype=np.int16)
|
|
|
106 |
:param context: Optional context for the transcription.
|
107 |
:return: Transcribed text.
|
108 |
"""
|
|
|
109 |
try:
|
110 |
if self.config.stt.type == "OPENAI_API":
|
111 |
+
audio_bytes = self.numpy_audio_to_bytes(audio)
|
112 |
data = ("temp.wav", audio_bytes, "audio/wav")
|
113 |
client = OpenAI(base_url=self.config.stt.url, api_key=self.config.stt.key)
|
114 |
transcription = client.audio.transcriptions.create(
|
115 |
model=self.config.stt.name, file=data, response_format="text", prompt=context
|
116 |
)
|
117 |
elif self.config.stt.type == "HF_API":
|
118 |
+
audio_bytes = self.numpy_audio_to_bytes(audio)
|
119 |
headers = {"Authorization": "Bearer " + self.config.stt.key}
|
120 |
response = requests.post(self.config.stt.url, headers=headers, data=audio_bytes)
|
121 |
if response.status_code != 200:
|
|
|
124 |
transcription = response.json().get("text", None)
|
125 |
if transcription is None:
|
126 |
raise APIError("STT Error: No transcription returned by HF API")
|
127 |
+
elif self.config.stt.type == "HF_LOCAL":
|
128 |
+
result = self.pipe({"sampling_rate": self.SAMPLE_RATE, "raw": audio.astype(np.float32) / 32768.0})
|
129 |
+
transcription = result["text"]
|
130 |
except APIError:
|
131 |
raise
|
132 |
except Exception as e:
|