"""This module implements Whisper transcription with a locally-downloaded model."""

import asyncio
import time

from enum import Enum

from typing_extensions import AsyncGenerator

import numpy as np

from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
from pipecat.services.ai_services import STTService

from loguru import logger

try:
    from faster_whisper import WhisperModel
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error(
        "In order to use Whisper, you need to `pip install pipecat-ai[whisper]`.")
    raise Exception(f"Missing module: {e}")


class Model(Enum):
    """Basic Whisper model selection options."""
    TINY = "tiny"
    BASE = "base"
    MEDIUM = "medium"
    LARGE = "large-v3"
    DISTIL_LARGE_V2 = "Systran/faster-distil-whisper-large-v2"
    DISTIL_MEDIUM_EN = "Systran/faster-distil-whisper-medium.en"


class WhisperSTTService(STTService):
    """Service that transcribes audio with a locally-downloaded Whisper model."""

    def __init__(self,
                 *,
                 model: str | Model = Model.DISTIL_MEDIUM_EN,
                 device: str = "auto",
                 compute_type: str = "default",
                 no_speech_prob: float = 0.4,
                 **kwargs):
        super().__init__(**kwargs)
        self._device: str = device
        self._compute_type = compute_type
        self._model_name: str | Model = model
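        # Segments whose no_speech_prob is at or above this threshold are
        # treated as non-speech and dropped from the transcription.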
        self._no_speech_prob = no_speech_prob
        self._model: WhisperModel | None = None
        self._load()

    def can_generate_metrics(self) -> bool:
        return True

    def _load(self):
        """Loads the Whisper model. Note that if this is the first time
        this model is being run, it will take time to download."""
        logger.debug("Loading Whisper model...")
        self._model = WhisperModel(
            self._model_name.value if isinstance(self._model_name, Enum) else self._model_name,
            device=self._device,
            compute_type=self._compute_type)
        logger.debug("Loaded Whisper model")

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Transcribes the given audio using Whisper."""
        if not self._model:
            logger.error(f"{self} error: Whisper model not available")
            yield ErrorFrame("Whisper model not available")
            return

        await self.start_ttfb_metrics()
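
        # Convert the incoming signed 16-bit PCM bytes to float32 samples in
        # [-1.0, 1.0); 32768 is the magnitude of the most negative int16 value.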
        audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0
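
        # Whisper inference is blocking, so run it in a worker thread to keep
        # the event loop (and the rest of the pipeline) responsive.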
        segments, _ = await asyncio.to_thread(self._model.transcribe, audio_float)
        text: str = ""
        for segment in segments:
            if segment.no_speech_prob < self._no_speech_prob:
                text += f"{segment.text} "

        if text:
            await self.stop_ttfb_metrics()
            logger.debug(f"Transcription: [{text}]")
            yield TranscriptionFrame(text, "", int(time.time_ns() / 1000000))