|
from pywhispercpp.model import Model |
|
import soundfile |
|
import config |
|
import numpy as np |
|
from logging import getLogger |
|
|
|
# Module-scoped logger, named after this module.
logger = getLogger(name=__name__)
|
|
|
|
|
class WhisperCPP:
    """Wrapper around a pywhispercpp ``Model`` used for speech-to-text.

    The model file is chosen from ``config`` by the source language given to
    the constructor; per-call decoding settings (initial prompt and maximum
    segment length) are chosen per language by :meth:`config_language`.
    """

    def __init__(self, source_lange: str = 'en', warmup: bool = True) -> None:
        """Load the whisper.cpp model and optionally run :meth:`warmup`.

        Args:
            source_lange: Language code selecting the model file. ``"zh"``
                loads ``config.WHISPER_MODEL_ZH``; any other value loads
                ``config.WHISPER_MODEL_EN``. (Name kept unchanged so existing
                keyword callers keep working.)
            warmup: When true, call :meth:`warmup` once after loading.
        """
        models_dir = config.MODEL_DIR.as_posix()
        whisper_model = (
            config.WHISPER_MODEL_ZH if source_lange == "zh"
            else config.WHISPER_MODEL_EN
        )
        self.model = Model(
            model=whisper_model,
            models_dir=models_dir,
            print_realtime=False,
            print_progress=False,
            print_timestamps=False,
            translate=False,
            temperature=0.0,
            no_context=True,
        )
        if warmup:
            self.warmup()

    def warmup(self, warmup_steps: int = 1) -> None:
        """Transcribe the bundled ``jfk.flac`` sample ``warmup_steps`` times.

        Args:
            warmup_steps: Number of warm-up transcriptions to run.
        """
        # soundfile.read returns (samples, sample_rate). Ask for float32 so the
        # samples have the same dtype that ``transcribe`` feeds the model,
        # instead of soundfile's float64 default.
        # NOTE(review): the sample rate is discarded; this assumes the clip is
        # already in the format the model expects (likely mono 16 kHz) — confirm.
        samples, _sample_rate = soundfile.read(
            f"{config.ASSERT_DIR}/jfk.flac", dtype="float32"
        )
        for _ in range(warmup_steps):
            self.model.transcribe(samples, print_progress=False)

    @staticmethod
    def config_language(language):
        """Return the ``(max_len, initial_prompt)`` settings for ``language``.

        Args:
            language: ``"zh"`` or ``"en"``.

        Returns:
            A ``(max_len, prompt)`` tuple read from ``config``.

        Raises:
            ValueError: If ``language`` is neither ``"zh"`` nor ``"en"``.
        """
        if language == "zh":
            # The English setting is spelled MAX_LENGTH_EN, but the Chinese
            # one has historically been read as MAX_LENTH_ZH. Prefer the
            # consistent spelling and fall back to the legacy one so either
            # config layout works. TODO: settle on one name in config.
            if hasattr(config, "MAX_LENGTH_ZH"):
                max_len = config.MAX_LENGTH_ZH
            else:
                max_len = config.MAX_LENTH_ZH
            return max_len, config.WHISPER_PROMPT_ZH
        if language == "en":
            return config.MAX_LENGTH_EN, config.WHISPER_PROMPT_EN
        raise ValueError(f"Unsupported language : {language}")

    def transcribe(self, audio_buffer: bytes, language: str) -> list:
        """Transcribe a buffer of raw float32 audio samples.

        Args:
            audio_buffer: Raw bytes reinterpreted (not decoded) as native-endian
                float32 samples. Presumably mono 16 kHz PCM — TODO confirm
                against callers.
            language: ``"zh"`` or ``"en"``; selects the prompt and maximum
                segment length and is forwarded to the model.

        Returns:
            The result of ``Model.transcribe``, or an empty list if the model
            raised during transcription.

        Raises:
            ValueError: If ``language`` is unsupported, or if the length of
                ``audio_buffer`` is not a multiple of 4 bytes.
        """
        max_len, prompt = self.config_language(language)
        samples = np.frombuffer(audio_buffer, dtype=np.float32)
        try:
            # whisper.cpp only honours split_on_word when max_len > 0, so the
            # per-language limit must be forwarded for word splitting to apply.
            return self.model.transcribe(
                samples,
                initial_prompt=prompt,
                language=language,
                max_len=max_len,
                split_on_word=True,
            )
        except Exception:
            # Deliberate best-effort: callers get an empty result rather than
            # an exception, but the full traceback is kept in the log.
            logger.exception("whisper.cpp transcription failed")
            return []