IliaLarchenko commited on
Commit
9c01649
1 Parent(s): 5efe29a

Added support of local STT models using HuggingFace

Browse files
Files changed (1) hide show
  1. api/audio.py +11 -2
api/audio.py CHANGED
@@ -10,6 +10,8 @@ from utils.errors import APIError, AudioConversionError
10
  from typing import List, Dict, Optional, Generator, Tuple
11
  import webrtcvad
12
 
 
 
13
 
14
  def detect_voice(audio: np.ndarray, sample_rate: int = 48000, frame_duration: int = 30) -> bool:
15
  vad = webrtcvad.Vad()
@@ -43,6 +45,9 @@ class STTManager:
43
  self.status = self.test_stt()
44
  self.streaming = self.status
45
 
 
 
 
46
  def numpy_audio_to_bytes(self, audio_data: np.ndarray) -> bytes:
47
  """
48
  Convert a numpy array of audio data to bytes.
@@ -79,7 +84,7 @@ class STTManager:
79
  if has_voice:
80
  audio_buffer = np.concatenate((audio_buffer, audio[1]))
81
 
82
- is_short = len(audio_buffer) / 48000 < 1.0
83
 
84
  if is_short or (has_voice and not ended):
85
  return audio_buffer, np.array([], dtype=np.int16)
@@ -101,15 +106,16 @@ class STTManager:
101
  :param context: Optional context for the transcription.
102
  :return: Transcribed text.
103
  """
104
- audio_bytes = self.numpy_audio_to_bytes(audio)
105
  try:
106
  if self.config.stt.type == "OPENAI_API":
 
107
  data = ("temp.wav", audio_bytes, "audio/wav")
108
  client = OpenAI(base_url=self.config.stt.url, api_key=self.config.stt.key)
109
  transcription = client.audio.transcriptions.create(
110
  model=self.config.stt.name, file=data, response_format="text", prompt=context
111
  )
112
  elif self.config.stt.type == "HF_API":
 
113
  headers = {"Authorization": "Bearer " + self.config.stt.key}
114
  response = requests.post(self.config.stt.url, headers=headers, data=audio_bytes)
115
  if response.status_code != 200:
@@ -118,6 +124,9 @@ class STTManager:
118
  transcription = response.json().get("text", None)
119
  if transcription is None:
120
  raise APIError("STT Error: No transcription returned by HF API")
 
 
 
121
  except APIError:
122
  raise
123
  except Exception as e:
 
10
  from typing import List, Dict, Optional, Generator, Tuple
11
  import webrtcvad
12
 
13
+ from transformers import pipeline
14
+
15
 
16
  def detect_voice(audio: np.ndarray, sample_rate: int = 48000, frame_duration: int = 30) -> bool:
17
  vad = webrtcvad.Vad()
 
45
  self.status = self.test_stt()
46
  self.streaming = self.status
47
 
48
+ if config.stt.type == "HF_LOCAL":
49
+ self.pipe = pipeline("automatic-speech-recognition", model=config.stt.name)
50
+
51
  def numpy_audio_to_bytes(self, audio_data: np.ndarray) -> bytes:
52
  """
53
  Convert a numpy array of audio data to bytes.
 
84
  if has_voice:
85
  audio_buffer = np.concatenate((audio_buffer, audio[1]))
86
 
87
+ is_short = len(audio_buffer) / self.SAMPLE_RATE < 1.0
88
 
89
  if is_short or (has_voice and not ended):
90
  return audio_buffer, np.array([], dtype=np.int16)
 
106
  :param context: Optional context for the transcription.
107
  :return: Transcribed text.
108
  """
 
109
  try:
110
  if self.config.stt.type == "OPENAI_API":
111
+ audio_bytes = self.numpy_audio_to_bytes(audio)
112
  data = ("temp.wav", audio_bytes, "audio/wav")
113
  client = OpenAI(base_url=self.config.stt.url, api_key=self.config.stt.key)
114
  transcription = client.audio.transcriptions.create(
115
  model=self.config.stt.name, file=data, response_format="text", prompt=context
116
  )
117
  elif self.config.stt.type == "HF_API":
118
+ audio_bytes = self.numpy_audio_to_bytes(audio)
119
  headers = {"Authorization": "Bearer " + self.config.stt.key}
120
  response = requests.post(self.config.stt.url, headers=headers, data=audio_bytes)
121
  if response.status_code != 200:
 
124
  transcription = response.json().get("text", None)
125
  if transcription is None:
126
  raise APIError("STT Error: No transcription returned by HF API")
127
+ elif self.config.stt.type == "HF_LOCAL":
128
+ result = self.pipe({"sampling_rate": self.SAMPLE_RATE, "raw": audio.astype(np.float32) / 32768.0})
129
+ transcription = result["text"]
130
  except APIError:
131
  raise
132
  except Exception as e: