# jts-ai-team's picture
# Upload 7 files
# abb09c3 verified
"""Speech-to-text utilities with graceful fallbacks."""
from __future__ import annotations
import numpy as np
from backend.utils import device
import nemo.collections.asr as nemo_asr
try:
import torch
from transformers import pipeline
except ModuleNotFoundError: # PyTorch or transformers not available on Python 3.13 wheels
torch = None # type: ignore
pipeline = None # type: ignore
try:
from google.cloud import speech
except ModuleNotFoundError:
speech = None # type: ignore
# Handle for the generic ASR callable invoked by `_transcribe_with_pipeline`.
# NOTE(review): nothing in the visible code assigns a non-None value to this
# name — confirm whether it is populated elsewhere before relying on it.
_ASR_PIPELINE = None
def _huggingface_device() -> int | str | None:
if device == "cuda":
return 0
if device == "mps":
return "mps"
return None
def _initialize_typhoon_pipeline():
    """Load the ``scb10x/typhoon-asr-realtime`` model through NeMo.

    Returns:
        The object returned by ``ASRModel.from_pretrained``, or ``None`` when
        ``torch`` or ``transformers`` could not be imported.
    """
    if torch is None or pipeline is None:
        return None

    # Choose the accelerator explicitly and fall back to CPU. Assuming "mps"
    # whenever CUDA is missing breaks on hosts that have neither backend.
    # ``torch.backends.mps`` is looked up defensively because older torch
    # releases do not ship it.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        target_device = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        target_device = "mps"
    else:
        target_device = "cpu"

    print(f"Using device: {target_device}")
    print("Initializing Typhoon ASR pipeline...")
    asr_model = nemo_asr.models.ASRModel.from_pretrained(
        model_name="scb10x/typhoon-asr-realtime",
        map_location=target_device,
    )
    print("Typhoon ASR pipeline initialized.")
    return asr_model
def _initialize_whisper_pipeline():
    """Build the ``nectec/Pathumma-whisper-th-medium`` speech-recognition pipeline.

    Returns:
        A transformers ASR pipeline configured to transcribe Thai, or ``None``
        when ``torch`` or ``transformers`` could not be imported.
    """
    # The guarded imports at the top of the module set these names to None
    # when the packages are missing; bail out instead of calling None.
    if torch is None or pipeline is None:
        return None

    pipe = pipeline(
        task="automatic-speech-recognition",
        model="nectec/Pathumma-whisper-th-medium",
        chunk_length_s=30,
        device=device,
        model_kwargs={"torch_dtype": torch.bfloat16},
    )
    # Pin the decoder prompt so the model always transcribes in Thai.
    pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(
        language='th',
        task="transcribe",
    )
    return pipe
# Typhoon model handle; loading is currently disabled, so it stays None.
_ASR_TYPHOON = None
# _ASR_TYPHOON = _initialize_typhoon_pipeline()
# Whisper pipeline handle, built eagerly when this module is imported.
_ASR_WHISPER = _initialize_whisper_pipeline()
def _transcribe_with_pipeline(audio_array: np.ndarray) -> str:
    """Transcribe audio with the module-level ``_ASR_PIPELINE`` callable.

    Args:
        audio_array: Audio samples passed unchanged to ``_ASR_PIPELINE``.

    Returns:
        The ``"text"`` entry when the pipeline returns a dict (empty string if
        the key is absent), otherwise ``str()`` of the output, with every
        occurrence of "ทางลัด" replaced by "ทางรัฐ".

    Raises:
        RuntimeError: If ``_ASR_PIPELINE`` has not been initialized.
    """
    # Fail with a clear message instead of "'NoneType' object is not callable".
    if _ASR_PIPELINE is None:
        raise RuntimeError("ASR pipeline is not initialized")

    output = _ASR_PIPELINE(audio_array)  # type: ignore[operator]
    text = output.get("text", "") if isinstance(output, dict) else str(output)
    # Presumably corrects a recurring mis-transcription — confirm with owner.
    return text.replace("ทางลัด", "ทางรัฐ")
def _transcribe_with_google(audio_array: np.ndarray) -> str:
    """Transcribe audio with Google Cloud Speech-to-Text.

    The samples are converted to 16-bit PCM and sent as LINEAR16 audio
    declared at 16 kHz, with Thai as the primary language and US English as
    an alternative.

    Args:
        audio_array: Audio samples; assumed to be floats normalized to
            [-1.0, 1.0] at 16 kHz — TODO confirm against callers.

    Returns:
        The top transcript of each result, joined with single spaces.

    Raises:
        RuntimeError: If ``google-cloud-speech`` could not be imported.
    """
    if speech is None:
        raise RuntimeError("google-cloud-speech is not available")

    # Saturate before scaling: samples outside [-1, 1] would otherwise
    # overflow the int16 range and come out as loud artifacts.
    clipped_audio = np.clip(audio_array, -1.0, 1.0)
    int16_audio = (clipped_audio * 32767.0).astype(np.int16)
    audio_bytes = int16_audio.tobytes()

    client = speech.SpeechClient()
    audio_config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
        sample_rate_hertz=16000,
        language_code="th-TH",
        alternative_language_codes=["en-US"],
        model="telephony",
    )
    audio_data = speech.RecognitionAudio(content=audio_bytes)
    response = client.recognize(config=audio_config, audio=audio_data)

    # Skip results with no alternatives instead of raising IndexError.
    return " ".join(
        result.alternatives[0].transcript
        for result in response.results
        if result.alternatives
    )
def transcribe_audio(audio_array: np.ndarray) -> str:
    """Transcribe user audio with the best available backend.

    The Whisper pipeline is tried first when it is loaded; if it is missing
    or raises, Google Cloud Speech-to-Text is used as the fallback.

    Args:
        audio_array: Audio samples to transcribe. ``None``, empty, or
            all-zero input is treated as silence.

    Returns:
        The transcription, or an empty string for silent input or when every
        backend fails.
    """
    if audio_array is None or not np.any(audio_array):
        return ""

    # NOTE: the Typhoon backend is disabled (`_ASR_TYPHOON` is None at module
    # level), so Whisper is the primary backend.
    if _ASR_WHISPER is not None:
        try:
            return _ASR_WHISPER(audio_array)["text"]
        except Exception as exc:
            # Best-effort: log and fall through to the Google fallback.
            print(f"Whisper ASR pipeline failed: {exc}")

    try:
        return _transcribe_with_google(audio_array)
    except Exception as exc:
        print(f"ASR fallback failed: {exc}")
    return ""