interviewer / api /audio.py
IliaLarchenko's picture
Minor refactoring
ab840dc
raw
history blame
8.18 kB
import io
import wave
import numpy as np
import requests
from openai import OpenAI
from utils.errors import APIError, AudioConversionError
class STTManager:
def __init__(self, config):
self.SAMPLE_RATE = 48000
self.CHUNK_LENGTH = 5
self.STEP_LENGTH = 3
self.MAX_RELIABILITY_CUTOFF = self.CHUNK_LENGTH - 1
self.config = config
self.status = self.test_stt()
self.streaming = self.test_streaming()
def numpy_audio_to_bytes(self, audio_data):
num_channels = 1
sampwidth = 2
buffer = io.BytesIO()
try:
with wave.open(buffer, "wb") as wf:
wf.setnchannels(num_channels)
wf.setsampwidth(sampwidth)
wf.setframerate(self.SAMPLE_RATE)
wf.writeframes(audio_data.tobytes())
except Exception as e:
raise AudioConversionError(f"Error converting numpy array to audio bytes: {e}")
return buffer.getvalue()
def process_audio_chunk(self, audio, audio_buffer, transcript):
"""Process streamed audio data to accumulate and transcribe with overlapping segments."""
audio_buffer = np.concatenate((audio_buffer, audio[1]))
if len(audio_buffer) >= self.SAMPLE_RATE * self.CHUNK_LENGTH or len(audio_buffer) % (self.SAMPLE_RATE // 2) != 0:
audio_bytes = self.numpy_audio_to_bytes(audio_buffer[: self.SAMPLE_RATE * self.CHUNK_LENGTH])
audio_buffer = audio_buffer[self.SAMPLE_RATE * self.STEP_LENGTH :]
new_transcript = self.speech_to_text_stream(audio_bytes)
transcript = self.merge_transcript(transcript, new_transcript)
return transcript, audio_buffer, transcript["text"]
def speech_to_text_stream(self, audio):
if self.config.stt.type == "HF_API":
raise APIError("STT Error: Streaming not supported for this STT type")
try:
data = ("temp.wav", audio, "audio/wav")
client = OpenAI(base_url=self.config.stt.url, api_key=self.config.stt.key)
transcription = client.audio.transcriptions.create(
model=self.config.stt.name, file=data, response_format="verbose_json", timestamp_granularities=["word"]
)
except APIError as e:
raise
except Exception as e:
raise APIError(f"STT Error: Unexpected error: {e}")
return transcription.words
def merge_transcript(self, transcript, new_transcript):
cut_off = transcript["last_cutoff"]
transcript["last_cutoff"] = self.MAX_RELIABILITY_CUTOFF - self.STEP_LENGTH
transcript["words"] = transcript["words"][: len(transcript["words"]) - transcript["not_confirmed"]]
transcript["not_confirmed"] = 0
first_word = True
for word_dict in new_transcript:
if word_dict["start"] >= cut_off:
if first_word:
if len(transcript["words"]) > 0 and transcript["words"][-1] == word_dict["word"]:
continue
first_word = False
transcript["words"].append(word_dict["word"])
if word_dict["start"] > self.MAX_RELIABILITY_CUTOFF:
transcript["not_confirmed"] += 1
else:
transcript["last_cutoff"] = max(1.0, word_dict["end"] - self.STEP_LENGTH)
transcript["text"] = " ".join(transcript["words"])
return transcript
def speech_to_text_full(self, audio):
audio = self.numpy_audio_to_bytes(audio[1])
try:
if self.config.stt.type == "OPENAI_API":
data = ("temp.wav", audio, "audio/wav")
client = OpenAI(base_url=self.config.stt.url, api_key=self.config.stt.key)
transcription = client.audio.transcriptions.create(model=self.config.stt.name, file=data, response_format="text")
elif self.config.stt.type == "HF_API":
headers = {"Authorization": "Bearer " + self.config.stt.key}
response = requests.post(self.config.stt.url, headers=headers, data=audio)
if response.status_code != 200:
error_details = response.json().get("error", "No error message provided")
raise APIError("STT Error: HF API error", status_code=response.status_code, details=error_details)
transcription = response.json().get("text", None)
if transcription is None:
raise APIError("STT Error: No transcription returned by HF API")
except APIError as e:
raise
except Exception as e:
raise APIError(f"STT Error: Unexpected error: {e}")
return transcription
def test_stt(self):
try:
self.speech_to_text_full((48000, np.zeros(10000)))
return True
except:
return False
def test_streaming(self):
try:
self.speech_to_text_stream(self.numpy_audio_to_bytes(np.zeros(10000)))
return True
except:
return False
class TTSManager:
def test_tts(self):
try:
self.read_text("Handshake")
return True
except:
return False
def test_tts_stream(self):
try:
for _ in self.read_text_stream("Handshake"):
pass
return True
except:
return False
def __init__(self, config):
self.config = config
self.status = self.test_tts()
if self.status:
self.streaming = self.test_tts_stream()
else:
self.streaming = False
if self.streaming:
self.read_last_message = self.rlm_stream
else:
self.read_last_message = self.rlm
def read_text(self, text):
headers = {"Authorization": "Bearer " + self.config.tts.key}
try:
if self.config.tts.type == "OPENAI_API":
data = {"model": self.config.tts.name, "input": text, "voice": "alloy", "response_format": "opus"}
response = requests.post(self.config.tts.url + "/audio/speech", headers=headers, json=data)
elif self.config.tts.type == "HF_API":
response = requests.post(self.config.tts.url, headers=headers, json={"inputs": text})
if response.status_code != 200:
error_details = response.json().get("error", "No error message provided")
raise APIError(f"TTS Error: {self.config.tts.type} error", status_code=response.status_code, details=error_details)
except APIError as e:
raise
except Exception as e:
raise APIError(f"TTS Error: Unexpected error: {e}")
return response.content
def read_text_stream(self, text):
if self.config.tts.type not in ["OPENAI_API"]:
raise APIError("TTS Error: Streaming not supported for this TTS type")
headers = {"Authorization": "Bearer " + self.config.tts.key}
data = {"model": self.config.tts.name, "input": text, "voice": "alloy", "response_format": "opus"}
try:
with requests.post(self.config.tts.url + "/audio/speech", headers=headers, json=data, stream=True) as response:
if response.status_code != 200:
error_details = response.json().get("error", "No error message provided")
raise APIError("TTS Error: OPENAI API error", status_code=response.status_code, details=error_details)
else:
yield from response.iter_content(chunk_size=1024)
except StopIteration:
pass
except APIError as e:
raise
except Exception as e:
raise APIError(f"TTS Error: Unexpected error: {e}")
def rlm(self, chat_history):
if len(chat_history) > 0 and chat_history[-1][1]:
return self.read_text(chat_history[-1][1])
def rlm_stream(self, chat_history):
if len(chat_history) > 0 and chat_history[-1][1]:
yield from self.read_text_stream(chat_history[-1][1])