Spaces:
Running
Running
import whisper | |
import gradio as gr | |
import time | |
from typing import BinaryIO, Union, Tuple, List | |
import numpy as np | |
import torch | |
import os | |
from argparse import Namespace | |
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR) | |
from modules.whisper.whisper_base import WhisperBase | |
from modules.whisper.whisper_parameter import * | |
class WhisperInference(WhisperBase):
    """Transcription backend built on the ``whisper`` package (``whisper.load_model``).

    The model is loaded lazily: ``transcribe`` calls ``update_model`` whenever no
    model is loaded yet, or when the requested model size / compute type differs
    from the settings the current model was loaded with.
    """

    def __init__(self,
                 model_dir: str = WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        """
        Parameters
        ----------
        model_dir: str
            Directory for whisper model files; ``update_model`` passes
            ``self.model_dir`` to ``whisper.load_model`` as ``download_root``
            (presumably ``WhisperBase`` stores this argument as ``self.model_dir`` — confirm).
        diarization_model_dir: str
            Directory for diarization models, forwarded to ``WhisperBase``.
        uvr_model_dir: str
            Directory for UVR models, forwarded to ``WhisperBase``.
        output_dir: str
            Output directory, forwarded to ``WhisperBase``.
        """
        super().__init__(
            model_dir=model_dir,
            output_dir=output_dir,
            diarization_model_dir=diarization_model_dir,
            uvr_model_dir=uvr_model_dir
        )

    def transcribe(self,
                   audio: Union[str, np.ndarray, torch.Tensor],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        Transcribe (or translate) audio with the whisper model.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            Audio file path, or audio samples as a numpy array / torch tensor.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        segments_result: List[dict]
            The ``"segments"`` entry of the whisper transcription result
            (per-segment dicts with start, end timestamps and transcribed text).
        elapsed_time: float
            Elapsed time in seconds, including any model (re)load done by this call.
        """
        # perf_counter is monotonic, so the measurement is immune to wall-clock changes.
        start_time = time.perf_counter()
        params = WhisperParameters.as_value(*whisper_params)

        needs_reload = (
            params.model_size != self.current_model_size
            or self.model is None
            or self.current_compute_type != params.compute_type
        )
        if needs_reload:
            self.update_model(params.model_size, params.compute_type, progress)

        # Translation is only requested from models listed in translatable_models;
        # every other model falls back to plain transcription.
        wants_translation = params.is_translate and self.current_model_size in self.translatable_models
        task = "translate" if wants_translation else "transcribe"
        # compute_type is not used when loading the model; it only toggles fp16 inference here.
        use_fp16 = params.compute_type == "float16"

        def progress_callback(progress_value):
            progress(progress_value, desc="Transcribing..")

        # NOTE(review): `progress_callback` is not a keyword of upstream openai-whisper's
        # transcribe(); this presumably relies on a whisper build that adds it — confirm.
        result = self.model.transcribe(
            audio=audio,
            language=params.lang,
            verbose=False,
            beam_size=params.beam_size,
            logprob_threshold=params.log_prob_threshold,
            no_speech_threshold=params.no_speech_threshold,
            task=task,
            fp16=use_fp16,
            best_of=params.best_of,
            patience=params.patience,
            temperature=params.temperature,
            compression_ratio_threshold=params.compression_ratio_threshold,
            progress_callback=progress_callback,
        )
        segments_result = result["segments"]

        elapsed_time = time.perf_counter() - start_time
        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress(),
                     ):
        """
        Load the whisper model ``model_size`` and record the settings it was loaded with.

        Parameters
        ----------
        model_size: str
            Size (name) of whisper model, passed to ``whisper.load_model`` as ``name``.
        compute_type: str
            Compute type for transcription. This backend does not use it when loading;
            ``transcribe`` only uses it to decide ``fp16`` ("float16" enables it).
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model..")
        # Load first and only then commit the model together with its settings.
        # Otherwise a failed load would leave current_model_size / current_compute_type
        # describing a model that was never loaded, and transcribe() would skip the
        # reload and keep using the previous model.
        model = whisper.load_model(
            name=model_size,
            device=self.device,
            download_root=self.model_dir,
        )
        self.model = model
        self.current_model_size = model_size
        self.current_compute_type = compute_type