Spaces:
Running
Running
File size: 2,434 Bytes
9712d04 780954b 9712d04 780954b 9712d04 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import os
from typing import Optional
import librosa
import numpy as np
from transformers.pipelines import pipeline
from .base import AbstractASRModel
from .registry import register_asr_model
# Hugging Face access token, read once at import time from the HF_TOKEN
# environment variable; None when the variable is not set.
hf_token = os.environ.get("HF_TOKEN")
@register_asr_model("openai/whisper")
class WhisperASR(AbstractASRModel):
    """ASR backend that runs a Whisper checkpoint through a Hugging Face
    ``automatic-speech-recognition`` pipeline.

    The pipeline is built once in ``__init__`` and reused by every
    ``transcribe`` call.
    """

    # Sample rate (Hz) audio is resampled to before it is given to the pipeline.
    _TARGET_SAMPLE_RATE = 16000

    def __init__(
        self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
    ):
        """
        Build the transcription pipeline.

        Args:
            model_id: Model identifier passed to ``pipeline(model=...)``
            device: Value passed to the pipeline as ``device_map``
            cache_dir: Directory injected into ``model_kwargs["cache_dir"]``
            **kwargs: Extra keyword arguments forwarded to the base class
                and to ``pipeline``
        """
        super().__init__(model_id, device, cache_dir, **kwargs)

        # Work on a copy so a ``model_kwargs`` dict supplied by the caller is
        # not mutated; ``cache_dir`` always wins over a value already in it.
        model_kwargs = dict(kwargs.get("model_kwargs") or {})
        model_kwargs["cache_dir"] = cache_dir
        kwargs["model_kwargs"] = model_kwargs

        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model_id,
            device_map=device,
            token=hf_token,
            **kwargs,
        )

    def transcribe(
        self,
        audio: np.ndarray,
        audio_sample_rate: int,
        language: Optional[str] = None,
        **kwargs,
    ) -> str:
        """
        Transcribe audio using Whisper model

        Args:
            audio: Audio numpy array
            audio_sample_rate: Sample rate of the audio
            language: Language hint (optional)
            **kwargs: Extra keyword arguments forwarded to the pipeline call.
                A ``generate_kwargs`` dict given here is merged with the
                language hint, and ``return_timestamps`` defaults to False
                unless supplied.

        Returns:
            Transcribed text as string, or an empty string if any step
            raised (the error is printed, not propagated)
        """
        try:
            audio = self._prepare_audio(audio, audio_sample_rate)

            # Merge instead of passing both explicit keywords and **kwargs:
            # a duplicate ``generate_kwargs`` / ``return_timestamps`` would
            # raise TypeError and be swallowed below as an empty result.
            pipe_kwargs = dict(kwargs)
            generate_kwargs = dict(pipe_kwargs.pop("generate_kwargs", None) or {})
            if language:
                generate_kwargs["language"] = language
            pipe_kwargs["generate_kwargs"] = generate_kwargs
            pipe_kwargs.setdefault("return_timestamps", False)

            result = self.pipe(audio, **pipe_kwargs)
            return self._extract_text(result)
        except Exception as e:
            # Deliberate best-effort: report the failure and return no text.
            print(f"Error during Whisper transcription: {e}")
            return ""

    @classmethod
    def _prepare_audio(cls, audio, audio_sample_rate):
        """Return ``audio`` as floating-point samples at the target rate."""
        # librosa.resample only accepts floating-point samples, so signed
        # integer PCM is rescaled to [-1, 1) first.
        # NOTE(review): assumes integer input is full-scale PCM - confirm
        # against callers.
        if isinstance(audio, np.ndarray) and np.issubdtype(
            audio.dtype, np.signedinteger
        ):
            full_scale = float(np.iinfo(audio.dtype).max) + 1.0
            audio = audio.astype(np.float32) / full_scale

        if audio_sample_rate != cls._TARGET_SAMPLE_RATE:
            audio = librosa.resample(
                audio, orig_sr=audio_sample_rate, target_sr=cls._TARGET_SAMPLE_RATE
            )
        return audio

    @staticmethod
    def _extract_text(result) -> str:
        """Pull the transcript string out of a pipeline result."""
        if isinstance(result, dict) and "text" in result:
            return result["text"]
        if isinstance(result, list) and len(result) > 0:
            # Only the first entry of a list result is used.
            first_result = result[0]
            if isinstance(first_result, dict):
                return first_result.get("text", str(first_result))
            return str(first_result)
        # Unknown shape: fall back to its string form.
        return str(result)
|