|
import gc
import os
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import astuple
from datetime import datetime
from typing import BinaryIO, Union, Tuple, List, Optional

import gradio as gr
import numpy as np
import torch
import torchaudio
import whisper
from faster_whisper.vad import VadOptions

from modules.uvr.music_separator import MusicSeparator
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
                                 UVR_MODELS_DIR)
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, get_csv, write_file, safe_filename
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
from modules.whisper.whisper_parameter import *
from modules.diarize.diarizer import Diarizer
from modules.vad.silero_vad import SileroVAD
from modules.translation.nllb_inference import NLLBInference
from modules.translation.nllb_inference import NLLB_AVAILABLE_LANGS
|
|
|
class WhisperBase(ABC):
    """Shared transcription pipeline for Whisper back-ends.

    Subclasses supply the model itself through `transcribe` and `update_model`.
    This base class wires in optional background-music separation, silero VAD
    filtering, speaker diarization, NLLB translation and subtitle file output
    for the gradio UI.
    """

    def __init__(self,
                 model_dir: str = WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        """
        Parameters
        ----------
        model_dir: str
            Directory for whisper model files. Created if missing.
        diarization_model_dir: str
            Model directory handed to the Diarizer.
        uvr_model_dir: str
            Model directory handed to the MusicSeparator.
        output_dir: str
            Directory for generated subtitle files. Created if missing; the music
            separator writes into its "UVR" sub-directory.
        """
        self.model_dir = model_dir
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.model_dir, exist_ok=True)

        self.diarizer = Diarizer(
            model_dir=diarization_model_dir
        )
        self.vad = SileroVAD()
        self.music_separator = MusicSeparator(
            model_dir=uvr_model_dir,
            output_dir=os.path.join(output_dir, "UVR")
        )
        # Created by transcribe_file() only when a translation is actually requested.
        self.nllb_inf = None

        self.model = None
        self.current_model_size = None
        self.available_models = whisper.available_models()
        self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
        self.translatable_models = whisper.available_models()
        self.device = self.get_device()
        self.available_compute_types = ["float16", "float32"]
        self.current_compute_type = "float16" if self.device == "cuda" else "float32"

    @abstractmethod
    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ):
        """Inference whisper model to transcribe.

        Implementations must return a ``(segments, elapsed_time)`` pair: `run`
        unpacks the result that way. `whisper_params` arrives as the flattened
        fields of the value object built by ``WhisperParameters.as_value``.
        """
        pass

    @abstractmethod
    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress()
                     ):
        """Initialize whisper model"""
        pass

    def run(self,
            audio: Union[str, BinaryIO, np.ndarray],
            progress: gr.Progress = gr.Progress(),
            add_timestamp: bool = True,
            *whisper_params,
            ) -> Tuple[List[dict], float]:
        """
        Run transcription with conditional pre-processing and post-processing.

        Pre-processing, each step only if enabled: background-music separation,
        then VAD filtering of silent parts. Post-processing: mapping timestamps
        back onto the un-filtered audio after VAD, then speaker diarization.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio input. This can be a file path, a binary stream or a waveform array.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        add_timestamp: bool
            Not used by this method; kept so gradio callers can pass their
            arguments positionally.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            Wall-clock seconds for the whole call, pre- and post-processing included.
        """
        start_time = datetime.now()
        params = WhisperParameters.as_value(*whisper_params)

        # Offload switches come from the cached YAML config, not from the UI parameters.
        # A missing section/key means "keep the model loaded" instead of a KeyError.
        default_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) or {}
        whisper_config = default_params.get("whisper") or {}
        diarization_config = default_params.get("diarization") or {}
        whisper_offload = whisper_config.get("enable_offload", False)
        diarization_offload = diarization_config.get("enable_offload", False)

        params.lang = self._to_language_code(params.lang)

        if params.is_bgm_separate:
            _music, audio, _ = self.music_separator.separate(
                audio=audio,
                model_name=params.uvr_model_size,
                device=params.uvr_device,
                segment_size=params.uvr_segment_size,
                save_file=params.uvr_save_file,
                progress=progress
            )

            if audio.ndim >= 2:
                # Downmix to mono; assumes a (samples, channels) layout — TODO confirm.
                audio = audio.mean(axis=1)
                if self.music_separator.audio_info is None:
                    origin_sample_rate = 16000
                else:
                    origin_sample_rate = self.music_separator.audio_info.sample_rate
                audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)

            if params.uvr_enable_offload:
                self.music_separator.offload()

        # Diarization needs the audio as it was before VAD removed silent parts.
        # Only arrays need copying (strings are immutable); deep-copying a file
        # object would raise, so other inputs are passed through as-is.
        origin_audio = None
        if params.is_diarize:
            origin_audio = deepcopy(audio) if isinstance(audio, np.ndarray) else audio

        speech_chunks = None
        if params.vad_filter:
            # A cap at or above 9999 s is treated as "no limit".
            if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
                params.max_speech_duration_s = float('inf')

            progress(0, desc="Filtering silent parts from audio...")
            vad_options = VadOptions(
                threshold=params.threshold,
                min_speech_duration_ms=params.min_speech_duration_ms,
                max_speech_duration_s=params.max_speech_duration_s,
                min_silence_duration_ms=params.min_silence_duration_ms,
                speech_pad_ms=params.speech_pad_ms
            )

            vad_processed, speech_chunks = self.vad.run(
                audio=audio,
                vad_parameters=vad_options,
                progress=progress
            )

            if vad_processed.size > 0:
                audio = vad_processed
            else:
                # Nothing kept by VAD: transcribe the full audio and skip timestamp restoration.
                params.vad_filter = False

        result, _transcription_time = self.transcribe(
            audio,
            progress,
            *astuple(params)
        )
        if whisper_offload:
            self.offload()

        if params.vad_filter:
            restored_result = self.vad.restore_speech_timestamps(
                segments=result,
                speech_chunks=speech_chunks,
            )
            if restored_result:
                result = restored_result
            else:
                print("VAD detected no speech segments in the audio.")

        if params.is_diarize:
            progress(0.99, desc="Diarizing speakers...")
            result, _diarization_time = self.diarizer.run(
                audio=origin_audio,
                use_auth_token=params.hf_token,
                transcribed_result=result,
                device=params.diarization_device
            )
            if diarization_offload:
                self.diarizer.offload()

        if not result:
            print("Whisper did not detect any speech segments in the audio.")
            result = list()

        progress(1.0, desc="Processing done!")
        total_elapsed_time = (datetime.now() - start_time).total_seconds()
        return result, total_elapsed_time

    @staticmethod
    def _to_language_code(lang: Optional[str]) -> Optional[str]:
        """Convert a UI language choice into the code whisper expects.

        None and "Automatic Detection" both mean auto-detect (returns None).
        A value that is already a key of whisper's language table is passed
        through unchanged; otherwise the language name is looked up (exact
        match first, then lower-cased). Unknown names raise KeyError.
        """
        if lang is None or lang == "Automatic Detection":
            return None
        languages = whisper.tokenizer.LANGUAGES
        if lang in languages:
            return lang
        name_to_code = {name: code for code, name in languages.items()}
        if lang in name_to_code:
            return name_to_code[lang]
        return name_to_code[lang.lower()]

    @staticmethod
    def _detect_language(model, file_path: str) -> Tuple[str, str]:
        """Guess the spoken language of `file_path` with whisper's language detector.

        Returns
        -------
        (language, probability): Tuple[str, str]
            Capitalized language name (e.g. "English") and the detector's
            probability for it as a whole-number percentage string. Both are ""
            if the top-scoring code is not in whisper's language table.
        """
        # Only the clip produced by pad_or_trim is analysed, not the whole file.
        clip = whisper.pad_or_trim(whisper.load_audio(file_path))
        mel = whisper.log_mel_spectrogram(clip).to(model.device)
        _, probs = model.detect_language(mel)
        best_code = max(probs, key=probs.get)
        language = whisper.tokenizer.LANGUAGES.get(best_code)
        if language is None:
            return "", ""
        return language.capitalize(), str(round(probs[best_code] * 100))

    def transcribe_file(self,
                        files: Optional[List] = None,
                        input_folder_path: Optional[str] = None,
                        file_format: str = "SRT",
                        add_timestamp: bool = True,
                        translate_output: bool = False,
                        translate_model: str = "",
                        target_lang: str = "",
                        progress=gr.Progress(),
                        *whisper_params,
                        ) -> list:
        """
        Write subtitle files from Files.

        Parameters
        ----------
        files: list
            List of files to transcribe from gr.Files()
        input_folder_path: str
            Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
            this will be used instead.
        file_format: str
            Subtitle format from gr.Dropdown(). Currently ignored: txt, srt and csv
            files are always written for every input.
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
        translate_output: bool
            Translate the transcription with NLLB.
        translate_model: str
            NLLB model size to use for translation.
        target_lang: str
            Target language for NLLB translation.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        [result_str, result_file_path, total_info]: list
            Concatenated plain-text transcriptions for gr.Textbox(), the written
            file paths for gr.Files(), and a per-file summary (detected language,
            translation notes, total processing time). Returns None if an
            exception was raised (it is printed, not propagated).
        """
        lang_detect_model = None
        try:
            if input_folder_path:
                files = get_media_files(input_folder_path)
            if isinstance(files, str):
                files = [files]
            files = list(files or [])
            if files and isinstance(files[0], gr.utils.NamedString):
                files = [file.name for file in files]

            time_start = datetime.now()
            params = WhisperParameters.as_value(*whisper_params)

            files_info = []
            output_paths = []
            nllb = None
            if files:
                # Separate small model, used only to report each file's language.
                lang_detect_model = whisper.load_model("base")

            for file in files:
                file_language, file_lang_probs = self._detect_language(lang_detect_model, file)

                transcribed_segments, time_for_task = self.run(
                    file,
                    progress,
                    add_timestamp,
                    *whisper_params,
                )

                # Language of the text whisper produced: whisper's own translate task outputs English.
                source_lang = file_language
                transcription_note = ""
                if params.is_translate:
                    transcription_note = "To English" if source_lang != "English" else "Already in English"
                    source_lang = "English"

                translation_note = ""
                if translate_output:
                    if source_lang == target_lang:
                        translation_note = "Already in " + target_lang
                    elif source_lang not in NLLB_AVAILABLE_LANGS:
                        translation_note = source_lang + " not supported"
                    else:
                        if nllb is None:
                            # One translator per call instead of one per file.
                            nllb = NLLBInference()
                            self.nllb_inf = nllb
                        transcribed_segments = nllb.translate_text(
                            input_list_dict=transcribed_segments,
                            model_size=translate_model,
                            src_lang=source_lang,
                            tgt_lang=target_lang,
                            speaker_diarization=params.is_diarize
                        )
                        translation_note = "To " + target_lang

                file_name, file_ext = os.path.splitext(os.path.basename(file))
                written = {}
                for output_format in ("txt", "srt", "csv"):
                    content, file_path = self.generate_and_write_file(
                        file_name=file_name,
                        transcribed_segments=transcribed_segments,
                        add_timestamp=add_timestamp,
                        file_format=output_format,
                        output_dir=self.output_dir
                    )
                    written[output_format] = content
                    output_paths.append(file_path)

                files_info.append({
                    "subtitle": written["txt"],
                    "time_for_task": time_for_task,
                    "lang": file_language,
                    "lang_prob": file_lang_probs,
                    "input_source_file": file_name + file_ext,
                    "translation": translation_note,
                    "transcription": transcription_note,
                })

            total_result = "".join(info["subtitle"] for info in files_info)

            info_parts = []
            for info in files_info:
                info_parts.append(f'Input file:\t\t{info["input_source_file"]}\nLanguage:\t{info["lang"]} (probability {info["lang_prob"]}%)\n')
                if params.is_translate:
                    info_parts.append(f'Translation:\t{info["transcription"]}\n\t⤷ Handled by OpenAI Whisper\n')
                if translate_output:
                    info_parts.append(f'Translation:\t{info["translation"]}\n\t⤷ Handled by Facebook NLLB\n')

            time_end = datetime.now()
            info_parts.append(f"\nTotal processing time: {self.format_time((time_end - time_start).total_seconds())}")
            total_info = "".join(info_parts)

            result_str = total_result.rstrip("\n")
            return [result_str, output_paths, total_info]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            # Drop the detection model first so empty_cache() can actually return its memory.
            del lang_detect_model
            self.release_cuda_memory()

    def transcribe_mic(self,
                       mic_audio: str,
                       file_format: str = "SRT",
                       add_timestamp: bool = True,
                       progress=gr.Progress(),
                       *whisper_params,
                       ) -> list:
        """
        Write subtitle file from microphone

        Parameters
        ----------
        mic_audio: str
            Audio file path from gr.Microphone()
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt, csv]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        Returns None if an exception was raised (it is printed, not propagated).
        """
        try:
            progress(0, desc="Loading Audio...")
            transcribed_segments, time_for_task = self.run(
                mic_audio,
                progress,
                add_timestamp,
                *whisper_params,
            )
            progress(1, desc="Completed!")

            subtitle, result_file_path = self.generate_and_write_file(
                file_name="Mic",
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format,
                output_dir=self.output_dir
            )

            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
            return [result_str, result_file_path]
        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()

    def transcribe_youtube(self,
                           youtube_link: str,
                           file_format: str = "SRT",
                           add_timestamp: bool = True,
                           progress=gr.Progress(),
                           *whisper_params,
                           ) -> list:
        """
        Write subtitle file from Youtube

        Parameters
        ----------
        youtube_link: str
            URL of the Youtube video to transcribe from gr.Textbox()
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt, csv]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        Returns None if an exception was raised (it is printed, not propagated).
        """
        audio = None
        try:
            progress(0, desc="Loading Audio from Youtube...")
            yt = get_ytdata(youtube_link)
            audio = get_ytaudio(yt)

            transcribed_segments, time_for_task = self.run(
                audio,
                progress,
                add_timestamp,
                *whisper_params,
            )

            progress(1, desc="Completed!")

            file_name = safe_filename(yt.title)
            subtitle, result_file_path = self.generate_and_write_file(
                file_name=file_name,
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format,
                output_dir=self.output_dir
            )
            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
            return [result_str, result_file_path]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            # Remove the downloaded audio on every exit path, not only on success.
            if isinstance(audio, str) and os.path.exists(audio):
                try:
                    os.remove(audio)
                except OSError as e:
                    print(f"Error transcribing file: {e}")
            self.release_cuda_memory()

    @staticmethod
    def generate_and_write_file(file_name: str,
                                transcribed_segments: list,
                                add_timestamp: bool,
                                file_format: str,
                                output_dir: str
                                ) -> Tuple[str, str]:
        """
        Writes subtitle file

        Parameters
        ----------
        file_name: str
            Output file name (without extension)
        transcribed_segments: list
            Text segments transcribed from audio
        add_timestamp: bool
            Determines whether to add a timestamp to the end of the filename.
        file_format: str
            File format to write, case-insensitive. Supported formats: [SRT, WebVTT (or vtt), txt, csv]
        output_dir: str
            Directory path of the output

        Returns
        ----------
        content: str
            Result of the transcription
        output_path: str
            output file path

        Raises
        ----------
        ValueError
            If `file_format` is not one of the supported formats.
        """
        if add_timestamp:
            timestamp = datetime.now().strftime("%Y%m%d %H%M%S")
            output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
        else:
            output_path = os.path.join(output_dir, f"{file_name}")

        # format -> (renderer, file extension)
        writers = {
            "srt": (get_srt, ".srt"),
            "webvtt": (get_vtt, ".vtt"),
            "vtt": (get_vtt, ".vtt"),
            "txt": (get_txt, ".txt"),
            "csv": (get_csv, ".csv"),
        }
        normalized_format = file_format.strip().lower()
        if normalized_format not in writers:
            raise ValueError(
                f"Unsupported file format: {file_format!r}. Supported formats: SRT, WebVTT, txt, csv"
            )

        render, extension = writers[normalized_format]
        content = render(transcribed_segments)
        output_path += extension

        write_file(content, output_path)
        return content, output_path

    def offload(self):
        """Offload the model and free up the memory"""
        if self.model is not None:
            del self.model
            self.model = None
        if self.device == "cuda":
            self.release_cuda_memory()
        gc.collect()

    @staticmethod
    def format_time(elapsed_time: float) -> str:
        """
        Get {hours} {minutes} {seconds} time format string

        Parameters
        ----------
        elapsed_time: float
            Elapsed time in seconds (a timedelta is also accepted).

        Returns
        ----------
        Time format string, e.g. "1 hour 2 minutes 3 seconds". Zero hours and
        minutes are omitted; seconds are always present.
        """
        if hasattr(elapsed_time, "total_seconds"):
            elapsed_time = elapsed_time.total_seconds()

        # Round once up front so a value like 119.7 s carries into the minutes
        # ("2 minutes 0 seconds") instead of producing "1 minute 60 seconds".
        total_seconds = round(elapsed_time)
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)

        parts = []
        if hours:
            parts.append(f"{hours} hour" if hours == 1 else f"{hours} hours")
        if minutes:
            parts.append(f"{minutes} minute" if minutes == 1 else f"{minutes} minutes")
        parts.append(f"{seconds} second" if seconds == 1 else f"{seconds} seconds")

        return " ".join(parts)

    @staticmethod
    def get_device():
        """Return "cuda" if available, else "mps" if it supports sparse tensors, else "cpu"."""
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available() and WhisperBase.is_sparse_api_supported():
            return "mps"
        return "cpu"

    @staticmethod
    def is_sparse_api_supported():
        """Return True if a small sparse COO tensor can be created on the MPS device."""
        if not torch.backends.mps.is_available():
            return False

        try:
            torch.sparse_coo_tensor(
                indices=torch.tensor([[0, 1], [2, 3]]),
                values=torch.tensor([1, 2]),
                size=(4, 4),
                device=torch.device("mps")
            )
        except RuntimeError:
            return False
        return True

    @staticmethod
    def release_cuda_memory():
        """Empty the CUDA caching allocator and reset peak-memory statistics (no-op without CUDA)."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # Non-deprecated replacement for torch.cuda.reset_max_memory_allocated().
            torch.cuda.reset_peak_memory_stats()

    @staticmethod
    def remove_input_files(file_paths: List[str]):
        """Remove gradio cached files; empty or missing paths are skipped."""
        if not file_paths:
            return

        for file_path in file_paths:
            if file_path and os.path.exists(file_path):
                os.remove(file_path)

    @staticmethod
    def cache_parameters(
            whisper_params: WhisperValues,
            add_timestamp: bool
    ):
        """Cache parameters to the yaml config file.

        Each top-level section returned by `whisper_params.to_yaml()` is merged
        into the existing section rather than replacing it. Config-only keys that
        the UI values do not carry (for example "enable_offload", read by `run`)
        are therefore preserved.
        """
        cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) or {}
        cached_whisper_param = whisper_params.to_yaml()

        cached_yaml = dict(cached_params)
        for section, values in cached_whisper_param.items():
            existing = cached_yaml.get(section)
            if isinstance(existing, dict) and isinstance(values, dict):
                cached_yaml[section] = {**existing, **values}
            else:
                cached_yaml[section] = values
        cached_yaml.setdefault("whisper", {})["add_timestamp"] = add_timestamp

        save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)

    @staticmethod
    def resample_audio(audio: Union[str, np.ndarray],
                       new_sample_rate: int = 16000,
                       original_sample_rate: Optional[int] = None,
                       ) -> np.ndarray:
        """Resample audio to `new_sample_rate` (16 kHz by default, standard for Whisper models).

        Parameters
        ----------
        audio: Union[str, np.ndarray]
            File path (loaded with torchaudio, which also supplies the source rate)
            or a waveform array.
        new_sample_rate: int
            Target sample rate.
        original_sample_rate: Optional[int]
            Source sample rate; required when `audio` is an array.

        Raises
        ----------
        ValueError
            If `audio` is an array and `original_sample_rate` is None.
        """
        if isinstance(audio, str):
            # NOTE(review): torchaudio.load returns a channels-first tensor by default,
            # so the result keeps that layout for path input — confirm callers expect it.
            audio, original_sample_rate = torchaudio.load(audio)
        else:
            if original_sample_rate is None:
                raise ValueError("original_sample_rate must be provided when audio is numpy array.")
            audio = torch.from_numpy(audio)
        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
        resampled_audio = resampler(audio).numpy()
        return resampled_audio
|
|