|  | import os | 
					
						
						|  | import time | 
					
						
						|  | import numpy as np | 
					
						
						|  | from typing import BinaryIO, Union, Tuple, List | 
					
						
						|  | import torch | 
					
						
						|  | from transformers import pipeline | 
					
						
						|  | from transformers.utils import is_flash_attn_2_available | 
					
						
						|  | import gradio as gr | 
					
						
						|  | from huggingface_hub import hf_hub_download | 
					
						
						|  | import whisper | 
					
						
						|  | from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn | 
					
						
						|  | from argparse import Namespace | 
					
						
						|  |  | 
					
						
						|  | from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) | 
					
						
						|  | from modules.whisper.whisper_parameter import * | 
					
						
						|  | from modules.whisper.whisper_base import WhisperBase | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class InsanelyFastWhisperInference(WhisperBase): | 
					
						
						|  | def __init__(self, | 
					
						
						|  | model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR, | 
					
						
						|  | diarization_model_dir: str = DIARIZATION_MODELS_DIR, | 
					
						
						|  | uvr_model_dir: str = UVR_MODELS_DIR, | 
					
						
						|  | output_dir: str = OUTPUT_DIR, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__( | 
					
						
						|  | model_dir=model_dir, | 
					
						
						|  | output_dir=output_dir, | 
					
						
						|  | diarization_model_dir=diarization_model_dir, | 
					
						
						|  | uvr_model_dir=uvr_model_dir | 
					
						
						|  | ) | 
					
						
						|  | self.model_dir = model_dir | 
					
						
						|  | os.makedirs(self.model_dir, exist_ok=True) | 
					
						
						|  |  | 
					
						
						|  | openai_models = whisper.available_models() | 
					
						
						|  | distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"] | 
					
						
						|  | self.available_models = openai_models + distil_models | 
					
						
						|  | self.available_compute_types = ["float16"] | 
					
						
						|  |  | 
					
						
						|  | def transcribe(self, | 
					
						
						|  | audio: Union[str, np.ndarray, torch.Tensor], | 
					
						
						|  | progress: gr.Progress = gr.Progress(), | 
					
						
						|  | *whisper_params, | 
					
						
						|  | ) -> Tuple[List[dict], float]: | 
					
						
						|  | """ | 
					
						
						|  | transcribe method for faster-whisper. | 
					
						
						|  |  | 
					
						
						|  | Parameters | 
					
						
						|  | ---------- | 
					
						
						|  | audio: Union[str, BinaryIO, np.ndarray] | 
					
						
						|  | Audio path or file binary or Audio numpy array | 
					
						
						|  | progress: gr.Progress | 
					
						
						|  | Indicator to show progress directly in gradio. | 
					
						
						|  | *whisper_params: tuple | 
					
						
						|  | Parameters related with whisper. This will be dealt with "WhisperParameters" data class | 
					
						
						|  |  | 
					
						
						|  | Returns | 
					
						
						|  | ---------- | 
					
						
						|  | segments_result: List[dict] | 
					
						
						|  | list of dicts that includes start, end timestamps and transcribed text | 
					
						
						|  | elapsed_time: float | 
					
						
						|  | elapsed time for transcription | 
					
						
						|  | """ | 
					
						
						|  | start_time = time.time() | 
					
						
						|  | params = WhisperParameters.as_value(*whisper_params) | 
					
						
						|  |  | 
					
						
						|  | if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: | 
					
						
						|  | self.update_model(params.model_size, params.compute_type, progress) | 
					
						
						|  |  | 
					
						
						|  | progress(0, desc="Transcribing...Progress is not shown in insanely-fast-whisper.") | 
					
						
						|  | with Progress( | 
					
						
						|  | TextColumn("[progress.description]{task.description}"), | 
					
						
						|  | BarColumn(style="yellow1", pulse_style="white"), | 
					
						
						|  | TimeElapsedColumn(), | 
					
						
						|  | ) as progress: | 
					
						
						|  | progress.add_task("[yellow]Transcribing...", total=None) | 
					
						
						|  |  | 
					
						
						|  | kwargs = { | 
					
						
						|  | "no_speech_threshold": params.no_speech_threshold, | 
					
						
						|  | "temperature": params.temperature, | 
					
						
						|  | "compression_ratio_threshold": params.compression_ratio_threshold, | 
					
						
						|  | "logprob_threshold": params.log_prob_threshold, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | if self.current_model_size.endswith(".en"): | 
					
						
						|  | pass | 
					
						
						|  | else: | 
					
						
						|  | kwargs["language"] = params.lang | 
					
						
						|  | kwargs["task"] = "translate" if params.is_translate else "transcribe" | 
					
						
						|  |  | 
					
						
						|  | segments = self.model( | 
					
						
						|  | inputs=audio, | 
					
						
						|  | return_timestamps=True, | 
					
						
						|  | chunk_length_s=params.chunk_length, | 
					
						
						|  | batch_size=params.batch_size, | 
					
						
						|  | generate_kwargs=kwargs | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | segments_result = self.format_result( | 
					
						
						|  | transcribed_result=segments, | 
					
						
						|  | ) | 
					
						
						|  | elapsed_time = time.time() - start_time | 
					
						
						|  | return segments_result, elapsed_time | 
					
						
						|  |  | 
					
						
						|  | def update_model(self, | 
					
						
						|  | model_size: str, | 
					
						
						|  | compute_type: str, | 
					
						
						|  | progress: gr.Progress = gr.Progress(), | 
					
						
						|  | ): | 
					
						
						|  | """ | 
					
						
						|  | Update current model setting | 
					
						
						|  |  | 
					
						
						|  | Parameters | 
					
						
						|  | ---------- | 
					
						
						|  | model_size: str | 
					
						
						|  | Size of whisper model | 
					
						
						|  | compute_type: str | 
					
						
						|  | Compute type for transcription. | 
					
						
						|  | see more info : https://opennmt.net/CTranslate2/quantization.html | 
					
						
						|  | progress: gr.Progress | 
					
						
						|  | Indicator to show progress directly in gradio. | 
					
						
						|  | """ | 
					
						
						|  | progress(0, desc="Initializing Model...") | 
					
						
						|  | model_path = os.path.join(self.model_dir, model_size) | 
					
						
						|  | if not os.path.isdir(model_path) or not os.listdir(model_path): | 
					
						
						|  | self.download_model( | 
					
						
						|  | model_size=model_size, | 
					
						
						|  | download_root=model_path, | 
					
						
						|  | progress=progress | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self.current_compute_type = compute_type | 
					
						
						|  | self.current_model_size = model_size | 
					
						
						|  | self.model = pipeline( | 
					
						
						|  | "automatic-speech-recognition", | 
					
						
						|  | model=os.path.join(self.model_dir, model_size), | 
					
						
						|  | torch_dtype=self.current_compute_type, | 
					
						
						|  | device=self.device, | 
					
						
						|  | model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"}, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def format_result( | 
					
						
						|  | transcribed_result: dict | 
					
						
						|  | ) -> List[dict]: | 
					
						
						|  | """ | 
					
						
						|  | Format the transcription result of insanely_fast_whisper as the same with other implementation. | 
					
						
						|  |  | 
					
						
						|  | Parameters | 
					
						
						|  | ---------- | 
					
						
						|  | transcribed_result: dict | 
					
						
						|  | Transcription result of the insanely_fast_whisper | 
					
						
						|  |  | 
					
						
						|  | Returns | 
					
						
						|  | ---------- | 
					
						
						|  | result: List[dict] | 
					
						
						|  | Formatted result as the same with other implementation | 
					
						
						|  | """ | 
					
						
						|  | result = transcribed_result["chunks"] | 
					
						
						|  | for item in result: | 
					
						
						|  | start, end = item["timestamp"][0], item["timestamp"][1] | 
					
						
						|  | if end is None: | 
					
						
						|  | end = start | 
					
						
						|  | item["start"] = start | 
					
						
						|  | item["end"] = end | 
					
						
						|  | return result | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def download_model( | 
					
						
						|  | model_size: str, | 
					
						
						|  | download_root: str, | 
					
						
						|  | progress: gr.Progress | 
					
						
						|  | ): | 
					
						
						|  | progress(0, 'Initializing model..') | 
					
						
						|  | print(f'Downloading {model_size} to "{download_root}"....') | 
					
						
						|  |  | 
					
						
						|  | os.makedirs(download_root, exist_ok=True) | 
					
						
						|  | download_list = [ | 
					
						
						|  | "model.safetensors", | 
					
						
						|  | "config.json", | 
					
						
						|  | "generation_config.json", | 
					
						
						|  | "preprocessor_config.json", | 
					
						
						|  | "tokenizer.json", | 
					
						
						|  | "tokenizer_config.json", | 
					
						
						|  | "added_tokens.json", | 
					
						
						|  | "special_tokens_map.json", | 
					
						
						|  | "vocab.json", | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  | if model_size.startswith("distil"): | 
					
						
						|  | repo_id = f"distil-whisper/{model_size}" | 
					
						
						|  | else: | 
					
						
						|  | repo_id = f"openai/whisper-{model_size}" | 
					
						
						|  | for item in download_list: | 
					
						
						|  | hf_hub_download(repo_id=repo_id, filename=item, local_dir=download_root) | 
					
						
						|  |  |